Commit 0bd6d2ce authored by rocking's avatar rocking
Browse files

Refactor reference code for different dimension

parent da0ddd48
......@@ -58,15 +58,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
using Argument = ReferenceAvgPoolBwd::Argument;
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
if constexpr(NDimSpatial == 1)
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
......@@ -104,7 +98,10 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
return 0;
}
else if constexpr(NDimSpatial == 2)
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
......@@ -134,8 +131,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto wo =
static_cast<ck::long_index_t>(w_tmp) /
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
......@@ -161,7 +157,10 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
return 0;
}
else if constexpr(NDimSpatial == 3)
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
......@@ -193,8 +192,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto ho =
static_cast<ck::long_index_t>(h_tmp) /
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
......@@ -240,6 +238,16 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
return RunAvgPoolBwd<NDimSpatial>(arg);
}
float Run(const device::BaseArgument* p_arg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment