Commit 0bd6d2ce authored by rocking's avatar rocking
Browse files

Refactor reference code for different dimension

parent da0ddd48
...@@ -58,165 +58,162 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -58,165 +58,162 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
using Argument = ReferenceAvgPoolBwd::Argument; using Argument = ReferenceAvgPoolBwd::Argument;
float Run(const Argument& arg) template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{ {
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 && auto f_ncw = [&](auto n, auto c, auto wi) {
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2)) std::size_t X = arg.window_spatial_lengths_[0];
{ std::size_t Wo = arg.doutput_.GetLengths()[2];
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1) float v_acc = 0;
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0; for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
for(std::size_t x = 0; x < X; ++x) if(w_tmp % arg.window_strides_[0] == 0)
{ {
auto w_tmp = static_cast<ck::long_index_t>(wi) + auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.window_strides_[0]);
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
if(w_tmp % arg.window_strides_[0] == 0) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
auto wo = static_cast<ck::long_index_t>(w_tmp) / v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
} }
} }
}
v_acc /= ck::type_convert<float>(X); v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc); arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])( arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2]; template <ck::index_t NDimSpatial_,
std::size_t Wo = arg.doutput_.GetLengths()[3]; typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0; float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
if(h_tmp % arg.window_strides_[0] == 0)
{ {
auto h_tmp = static_cast<ck::long_index_t>(hi) + auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.window_strides_[0]);
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]); if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
if(h_tmp % arg.window_strides_[0] == 0)
{ {
auto ho = static_cast<ck::long_index_t>(h_tmp) / for(std::size_t x = 0; x < X; ++x)
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{ {
auto w_tmp = auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(arg.window_strides_[1]);
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) - if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{ {
auto wo = v_acc +=
static_cast<ck::long_index_t>(w_tmp) / ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
} }
} }
} }
} }
} }
}
v_acc /= ck::type_convert<float>(Y * X); v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc); arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2], arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])( arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
} }
else if constexpr(NDimSpatial == 3)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z) template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
if(d_tmp % arg.window_strides_[0] == 0)
{ {
auto d_tmp = static_cast<ck::long_index_t>(di) + auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.window_strides_[0]);
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]); if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
if(d_tmp % arg.window_strides_[0] == 0)
{ {
auto do_ = static_cast<ck::long_index_t>(d_tmp) / for(std::size_t y = 0; y < Y; ++y)
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(std::size_t y = 0; y < Y; ++y) auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{ {
auto h_tmp = auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(hi) + static_cast<ck::long_index_t>(arg.window_strides_[1]);
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) - if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{ {
auto ho = for(std::size_t x = 0; x < X; ++x)
static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) auto w_tmp = static_cast<ck::long_index_t>(wi) +
{ static_cast<ck::long_index_t>(
auto w_tmp = static_cast<ck::long_index_t>(wi) + arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) - x * arg.window_dilations_[2]);
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
if(w_tmp % arg.window_strides_[2] == 0) if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{ {
auto wo = static_cast<ck::long_index_t>(w_tmp) / v_acc += ck::type_convert<float>(
static_cast<ck::long_index_t>( arg.doutput_(n, c, do_, ho, wo));
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
} }
} }
} }
...@@ -225,21 +222,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -225,21 +222,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
} }
} }
} }
}
v_acc /= ck::type_convert<float>(Z * Y * X); v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc); arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2], arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3], arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])( arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
} }
return RunAvgPoolBwd<NDimSpatial>(arg);
} }
float Run(const device::BaseArgument* p_arg, float Run(const device::BaseArgument* p_arg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment