Commit 0bd6d2ce authored by rocking's avatar rocking
Browse files

Refactor reference code for different dimension

parent da0ddd48
......@@ -58,165 +58,162 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
using Argument = ReferenceAvgPoolBwd::Argument;
float Run(const Argument& arg)
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0;
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
for(std::size_t x = 0; x < X; ++x)
if(w_tmp % arg.window_strides_[0] == 0)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(w_tmp % arg.window_strides_[0] == 0)
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
}
}
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 2)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
return 0;
}
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0;
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
if(h_tmp % arg.window_strides_[0] == 0)
{
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
if(h_tmp % arg.window_strides_[0] == 0)
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
for(std::size_t x = 0; x < X; ++x)
{
for(std::size_t x = 0; x < X; ++x)
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
auto wo =
static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 3)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
return 0;
}
for(std::size_t z = 0; z < Z; ++z)
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
if(d_tmp % arg.window_strides_[0] == 0)
{
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
if(d_tmp % arg.window_strides_[0] == 0)
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
for(std::size_t y = 0; y < Y; ++y)
{
for(std::size_t y = 0; y < Y; ++y)
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
auto ho =
static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
for(std::size_t x = 0; x < X; ++x)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
if(w_tmp % arg.window_strides_[2] == 0)
if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
}
}
......@@ -225,21 +222,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
}
}
}
}
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
return RunAvgPoolBwd<NDimSpatial>(arg);
}
float Run(const device::BaseArgument* p_arg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment