"...composable_kernel.git" did not exist on "0bdbd358ecee2eb434e27bc900a3e248bcd2b2bd"
Commit af9b4f25 authored by rocking's avatar rocking
Browse files

Refine naming

parent 6ab0ace0
...@@ -14,52 +14,43 @@ namespace ck { ...@@ -14,52 +14,43 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace host { namespace host {
// input descriptor in [N, C, Do, Ho, Wo] order // dinput descriptor in [N, C, Do, Ho, Wo] order
// output descriptor in [N, C, Di, Hi, Wi] order // doutput descriptor in [N, C, Di, Hi, Wi] order
// phyiscal layout is irrelavent // phyiscal layout is irrelavent
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename DInDataType,
typename OutDataType, typename DOutDataType,
typename InElementwiseOperation,
typename OutElementwiseOperation,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(Tensor<InDataType>& input, Argument(Tensor<DInDataType>& dinput,
const Tensor<OutDataType>& output, const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> dinput_right_pads)
InElementwiseOperation in_element_op, : dinput_{dinput},
OutElementwiseOperation out_element_op) doutput_{doutput},
: input_{input},
output_{output},
window_spatial_lengths_{window_spatial_lengths}, window_spatial_lengths_{window_spatial_lengths},
conv_strides_{window_strides}, window_strides_{window_strides},
conv_dilations_{window_dilations}, window_dilations_{window_dilations},
in_left_pads_{input_left_pads}, in_left_pads_{dinput_left_pads},
in_right_pads_{input_right_pads}, in_right_pads_{dinput_right_pads}
in_element_op_{in_element_op},
out_element_op_{out_element_op}
{ {
} }
Tensor<InDataType>& input_; Tensor<DInDataType>& dinput_;
const Tensor<OutDataType>& output_; const Tensor<DOutDataType>& doutput_;
std::vector<ck::index_t> window_spatial_lengths_; std::vector<ck::index_t> window_spatial_lengths_;
std::vector<index_t> conv_strides_; std::vector<index_t> window_strides_;
std::vector<index_t> conv_dilations_; std::vector<index_t> window_dilations_;
std::vector<index_t> in_left_pads_; std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<index_t> in_right_pads_;
InElementwiseOperation in_element_op_;
OutElementwiseOperation out_element_op_;
}; };
// Invoker // Invoker
...@@ -69,8 +60,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -69,8 +60,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
float Run(const Argument& arg) float Run(const Argument& arg)
{ {
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 2 && if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.output_.GetNumOfDimension() == NDimSpatial + 2)) arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{ {
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
...@@ -79,7 +70,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -79,7 +70,7 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
auto f_ncw = [&](auto n, auto c, auto wi) { auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0]; std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.output_.GetLengths()[2]; std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0; float v_acc = 0;
...@@ -87,35 +78,28 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -87,35 +78,28 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
auto w_tmp = static_cast<ck::long_index_t>(wi) + auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
if(w_tmp % arg.conv_strides_[0] == 0) if(w_tmp % arg.window_strides_[0] == 0)
{ {
auto wo = static_cast<ck::long_index_t>(w_tmp) / auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
float v_out = 0; v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
arg.out_element_op_(v_out,
ck::type_convert<float>(arg.output_(n, c, wo)));
v_acc += v_out;
} }
} }
} }
v_acc /= ck::type_convert<float>(X);
float v_in;
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc); v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncw, make_ParallelTensorFunctor(f_ncw,
arg.input_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.input_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.input_.GetLengths()[2])( arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -126,8 +110,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -126,8 +110,8 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std::size_t Y = arg.window_spatial_lengths_[0]; std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1]; std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.output_.GetLengths()[2]; std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.output_.GetLengths()[3]; std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0; float v_acc = 0;
...@@ -135,11 +119,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -135,11 +119,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
auto h_tmp = static_cast<ck::long_index_t>(hi) + auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
if(h_tmp % arg.conv_strides_[0] == 0) if(h_tmp % arg.window_strides_[0] == 0)
{ {
auto ho = static_cast<ck::long_index_t>(h_tmp) / auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
...@@ -147,40 +131,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -147,40 +131,32 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto w_tmp = auto w_tmp =
static_cast<ck::long_index_t>(wi) + static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]); static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.conv_strides_[1] == 0) if(w_tmp % arg.window_strides_[1] == 0)
{ {
auto wo = auto wo =
static_cast<ck::long_index_t>(w_tmp) / static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo) if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{ {
float v_out = 0; v_acc +=
arg.out_element_op_( ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
v_out,
ck::type_convert<float>(arg.output_(n, c, ho, wo)));
v_acc += v_out;
} }
} }
} }
} }
} }
} }
v_acc /= ck::type_convert<float>(Y * X);
float v_in;
arg.in_element_op_(v_in, v_acc); v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_nchw, make_ParallelTensorFunctor(f_nchw,
arg.input_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.input_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.input_.GetLengths()[2], arg.dinput_.GetLengths()[2],
arg.input_.GetLengths()[3])( arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -192,9 +168,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -192,9 +168,9 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
std::size_t Y = arg.window_spatial_lengths_[1]; std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2]; std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.output_.GetLengths()[2]; std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.output_.GetLengths()[3]; std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.output_.GetLengths()[4]; std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0; float v_acc = 0;
...@@ -202,11 +178,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -202,11 +178,11 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
{ {
auto d_tmp = static_cast<ck::long_index_t>(di) + auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) - static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]); static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
if(d_tmp % arg.conv_strides_[0] == 0) if(d_tmp % arg.window_strides_[0] == 0)
{ {
auto do_ = static_cast<ck::long_index_t>(d_tmp) / auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[0]); static_cast<ck::long_index_t>(arg.window_strides_[0]);
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do) if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{ {
for(std::size_t y = 0; y < Y; ++y) for(std::size_t y = 0; y < Y; ++y)
...@@ -214,12 +190,12 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -214,12 +190,12 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
auto h_tmp = auto h_tmp =
static_cast<ck::long_index_t>(hi) + static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) - static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]); static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.conv_strides_[1] == 0) if(h_tmp % arg.window_strides_[1] == 0)
{ {
auto ho = auto ho =
static_cast<ck::long_index_t>(h_tmp) / static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.conv_strides_[1]); static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho) if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{ {
for(std::size_t x = 0; x < X; ++x) for(std::size_t x = 0; x < X; ++x)
...@@ -228,23 +204,18 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -228,23 +204,18 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) - arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]); x * arg.window_dilations_[2]);
if(w_tmp % arg.conv_strides_[2] == 0) if(w_tmp % arg.window_strides_[2] == 0)
{ {
auto wo = static_cast<ck::long_index_t>(w_tmp) / auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
arg.conv_strides_[2]); arg.window_strides_[2]);
if(wo >= 0 && if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo) ck::type_convert<std::size_t>(wo) < Wo)
{ {
float v_out = 0; v_acc += ck::type_convert<float>(
arg.out_element_op_( arg.doutput_(n, c, do_, ho, wo));
v_out,
ck::type_convert<float>(
arg.output_(n, c, do_, ho, wo)));
v_acc += v_out;
} }
} }
} }
...@@ -254,21 +225,17 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -254,21 +225,17 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
} }
} }
} }
v_acc /= ck::type_convert<float>(Z * Y * X);
float v_in; v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
arg.in_element_op_(v_in, v_acc);
arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ncdhw, make_ParallelTensorFunctor(f_ncdhw,
arg.input_.GetLengths()[0], arg.dinput_.GetLengths()[0],
arg.input_.GetLengths()[1], arg.dinput_.GetLengths()[1],
arg.input_.GetLengths()[2], arg.dinput_.GetLengths()[2],
arg.input_.GetLengths()[3], arg.dinput_.GetLengths()[3],
arg.input_.GetLengths()[4])( arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
return 0; return 0;
...@@ -290,30 +257,26 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator ...@@ -290,30 +257,26 @@ struct ReferenceAvgPoolBwd : public device::BaseOperator
bool IsSupportedArgument(const device::BaseArgument*) override { return true; } bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<InDataType>& input, static auto MakeArgument(Tensor<DInDataType>& dinput,
const Tensor<OutDataType>& output, const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> dinput_right_pads)
InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op)
{ {
if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial) dinput_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
return Argument{input, return Argument{dinput,
output, doutput,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations, window_dilations,
input_left_pads, dinput_left_pads,
input_right_pads, dinput_right_pads};
in_element_op,
out_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment