"...composable_kernel_rocm.git" did not exist on "e8cddfdc3bc7fbdec765ee0bfbb391ef7173b455"
Commit af9b4f25 authored by rocking's avatar rocking
Browse files

Refine naming

parent 6ab0ace0
......@@ -14,52 +14,43 @@ namespace ck {
namespace tensor_operation {
namespace host {
// input descriptor in [N, C, Do, Ho, Wo] order
// output descriptor in [N, C, Di, Hi, Wi] order
// dinput descriptor in [N, C, Do, Ho, Wo] order
// doutput descriptor in [N, C, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename InDataType,
typename OutDataType,
typename InElementwiseOperation,
typename OutElementwiseOperation,
typename DInDataType,
typename DOutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(Tensor<InDataType>& input,
const Tensor<OutDataType>& output,
Argument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op)
: input_{input},
output_{output},
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
: dinput_{dinput},
doutput_{doutput},
window_spatial_lengths_{window_spatial_lengths},
conv_strides_{window_strides},
conv_dilations_{window_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
in_element_op_{in_element_op},
out_element_op_{out_element_op}
window_strides_{window_strides},
window_dilations_{window_dilations},
in_left_pads_{dinput_left_pads},
in_right_pads_{dinput_right_pads}
{
}
Tensor<InDataType>& input_;
const Tensor<OutDataType>& output_;
Tensor<DInDataType>& dinput_;
const Tensor<DOutDataType>& doutput_;
std::vector<ck::index_t> window_spatial_lengths_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> window_strides_;
std::vector<index_t> window_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
OutElementwiseOperation out_element_op_;
};
// Invoker
......@@ -69,8 +60,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
float Run(const Argument& arg)
{
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 2))
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
......@@ -79,7 +70,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.output_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0;
......@@ -87,35 +78,28 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]);
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0)
if(w_tmp % arg.window_strides_[0] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
float v_out = 0;
arg.out_element_op_(v_out,
ck::type_convert<float>(arg.output_(n, c, wo)));
v_acc += v_out;
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
}
}
v_acc /= ck::type_convert<float>(X);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2])(
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
......@@ -126,8 +110,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.output_.GetLengths()[2];
std::size_t Wo = arg.output_.GetLengths()[3];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0;
......@@ -135,11 +119,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0)
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
if(h_tmp % arg.window_strides_[0] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
......@@ -147,40 +131,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0)
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto wo =
static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
float v_out = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(arg.output_(n, c, ho, wo)));
v_acc += v_out;
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Y * X);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3])(
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
......@@ -192,9 +168,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.output_.GetLengths()[2];
std::size_t Ho = arg.output_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
......@@ -202,11 +178,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0)
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
if(d_tmp % arg.window_strides_[0] == 0)
{
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]);
static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
......@@ -214,12 +190,12 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0)
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto ho =
static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]);
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
......@@ -228,23 +204,18 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]);
x * arg.window_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0)
if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.conv_strides_[2]);
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
float v_out = 0;
arg.out_element_op_(
v_out,
ck::type_convert<float>(
arg.output_(n, c, do_, ho, wo)));
v_acc += v_out;
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
}
}
......@@ -254,21 +225,17 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
}
}
}
v_acc /= ck::type_convert<float>(Z * Y * X);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.input_.GetLengths()[0],
arg.input_.GetLengths()[1],
arg.input_.GetLengths()[2],
arg.input_.GetLengths()[3],
arg.input_.GetLengths()[4])(
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
......@@ -290,30 +257,26 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<InDataType>& input,
const Tensor<OutDataType>& output,
static auto MakeArgument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op)
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
dinput_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return Argument{input,
output,
return Argument{dinput,
doutput,
window_spatial_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
in_element_op,
out_element_op};
dinput_left_pads,
dinput_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment