Commit dd6a8de4 authored by Jehandad Khan

Merge branch 'develop' into jd/dev_pkg

parents 0aa899aa abf4bdb9
@@ -398,7 +398,7 @@ void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),
...
@@ -230,7 +230,7 @@ void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),
...
@@ -499,7 +499,7 @@ void device_gemm_xdlops_mk_nk_mn(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),
...
@@ -286,7 +286,7 @@ void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k,
         ABType,
         AccType,
         CType,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_k0_m_k1_grid_desc),
         decltype(b_k0_n_k1_grid_desc),
         decltype(c_m_n_grid_desc),
...
@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_GK0_GM0_GM1_GK1,
           typename BGridDesc_GK0_GN0_GN1_GK1,
           typename CGridDesc_GM0_GM1_GN0_GN1,
...
@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add
 {
     template <typename... Wei,
@@ -294,7 +294,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),
...
@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad
 {
     template <typename... Wei,
@@ -260,7 +260,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),
...
@@ -27,7 +27,7 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
           ck::index_t CThreadTransferDstScalarPerVector_K,
-          ck::ActivTypeEnum_t activ_type>
+          ck::ActivTypeEnum activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
 {
     template <typename... Wei,
@@ -305,7 +305,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
         FloatAB,
         FloatAcc,
         FloatC,
-        InMemoryDataOperationEnum_t::Set,
+        InMemoryDataOperationEnum::Set,
         decltype(a_e0_e1_k_e2_grid_desc),
         decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
         decltype(c_k_n_hop_wop_grid_desc),
...
@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AKMGridDesc,
           typename BKNGridDesc,
           typename CMNGridDesc,
...
@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AK0MK1GridDesc,
           typename BK0NK1GridDesc,
           typename CMNGridDesc,
...
@@ -11,7 +11,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K,
           typename CMNGridDesc,
...
@@ -10,7 +10,7 @@ template <ck::index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
           typename FloatC,
-          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
+          ck::InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CMNGridDesc,
...
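A note on the hunks above: they are all the same mechanical rename, dropping the legacy "_t" suffix from CK's scoped enums (InMemoryDataOperationEnum_t -> InMemoryDataOperationEnum, ActivTypeEnum_t -> ActivTypeEnum). A minimal sketch of the renamed enum used as a non-type template parameter; the enumerators besides Set are assumptions for illustration:

    // Sketch only: the enumerator list beyond Set is assumed, not taken from this diff.
    enum class InMemoryDataOperationEnum
    {
        Set,      // plain store of the result to global memory
        AtomicAdd // accumulate atomically (assumed enumerator)
    };

    // Kernels take the operation as a non-type template parameter, as in the hunks above.
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
    struct GridwiseGemmSketch
    {
        static constexpr bool uses_atomics =
            CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd;
    };

Since only the type names change, call sites are updated textually with no behavioral difference.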
@@ -17,7 +17,7 @@ template <typename InDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
           typename OutElementwiseOperation>
-struct ReferenceConvWrw : public device::BaseOperator
+struct ReferenceConvBwdWeight : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument
@@ -62,7 +62,7 @@ struct ReferenceConvWrw : public device::BaseOperator
     // Invoker
     struct Invoker : public device::BaseInvoker
     {
-        using Argument = ReferenceConvWrw::Argument;
+        using Argument = ReferenceConvBwdWeight::Argument;
         float Run(const Argument& arg)
         {
@@ -163,7 +163,7 @@ struct ReferenceConvWrw : public device::BaseOperator
         auto str = std::stringstream();
         // clang-format off
-        str << "ReferenceConvFwd"
+        str << "ReferenceConvBwdWeight"
             << std::endl;
         // clang-format on
...
@@ -14,17 +14,20 @@ namespace host {
 template <typename InDataType,
           typename WeiDataType,
           typename OutDataType,
+          typename AccDataType,
           typename InElementwiseOperation,
           typename WeiElementwiseOperation,
-          typename OutElementwiseOperation>
+          typename OutElementwiseOperation,
+          ck::index_t NumDimSpatial = 2,
+          typename ck::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
 struct ReferenceConvBwdData : public device::BaseOperator
 {
     // Argument
     struct Argument : public device::BaseArgument
     {
-        Argument(Tensor<InDataType>& in_n_c_hi_wi,
-                 const Tensor<WeiDataType>& wei_k_c_y_x,
-                 const Tensor<OutDataType>& out_n_k_ho_wo,
+        Argument(Tensor<InDataType>& input,
+                 const Tensor<WeiDataType>& weight,
+                 const Tensor<OutDataType>& output,
                  std::vector<ck::index_t> conv_filter_strides,
                  std::vector<ck::index_t> conv_filter_dilations,
                  std::vector<ck::index_t> input_left_pads,
@@ -32,9 +35,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
                  InElementwiseOperation in_element_op,
                  WeiElementwiseOperation wei_element_op,
                  OutElementwiseOperation out_element_op)
-            : in_n_c_hi_wi_{in_n_c_hi_wi},
-              wei_k_c_y_x_{wei_k_c_y_x},
-              out_n_k_ho_wo_{out_n_k_ho_wo},
+            : input_{input},
+              weight_{weight},
+              output_{output},
               conv_strides_{conv_filter_strides},
               conv_dilations_{conv_filter_dilations},
               in_left_pads_{input_left_pads},
@@ -45,9 +48,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
         {
         }

-        Tensor<InDataType>& in_n_c_hi_wi_;
-        const Tensor<WeiDataType>& wei_k_c_y_x_;
-        const Tensor<OutDataType>& out_n_k_ho_wo_;
+        Tensor<InDataType>& input_;
+        const Tensor<WeiDataType>& weight_;
+        const Tensor<OutDataType>& output_;
         std::vector<index_t> conv_strides_;
         std::vector<index_t> conv_dilations_;
@@ -66,67 +69,199 @@ struct ReferenceConvBwdData : public device::BaseOperator
     float Run(const Argument& arg)
     {
-        auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
-            std::size_t K = arg.wei_k_c_y_x_.mDesc.GetLengths()[0];
-            std::size_t Y = arg.wei_k_c_y_x_.mDesc.GetLengths()[2];
-            std::size_t X = arg.wei_k_c_y_x_.mDesc.GetLengths()[3];
-
-            std::size_t Ho = arg.out_n_k_ho_wo_.mDesc.GetLengths()[2];
-            std::size_t Wo = arg.out_n_k_ho_wo_.mDesc.GetLengths()[3];
-
-            float v_acc = 0;
-
-            for(int y = 0; y < Y; ++y)
-            {
-                int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
-                if(h_tmp % arg.conv_strides_[0] == 0)
-                {
-                    int ho = h_tmp / arg.conv_strides_[0];
-                    if(ho >= 0 && ho < Ho)
-                    {
-                        for(int x = 0; x < X; ++x)
-                        {
-                            int w_tmp = wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
-                            if(w_tmp % arg.conv_strides_[1] == 0)
-                            {
-                                int wo = w_tmp / arg.conv_strides_[1];
-                                if(wo >= 0 && wo < Wo)
-                                {
-                                    for(int k = 0; k < K; ++k)
-                                    {
-                                        float v_out = 0;
-                                        float v_wei = 0;
-
-                                        arg.out_element_op_(
-                                            v_out,
-                                            ck::type_convert<float>(
-                                                arg.out_n_k_ho_wo_(n, k, ho, wo)));
-                                        arg.wei_element_op_(v_wei,
-                                                            ck::type_convert<float>(
-                                                                arg.wei_k_c_y_x_(k, c, y, x)));
-
-                                        v_acc += v_out * v_wei;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-
-            float v_in;
-            arg.in_element_op_(v_in, v_acc);
-            arg.in_n_c_hi_wi_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
-        };
-
-        make_ParallelTensorFunctor(f_nchw,
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[0],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[1],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[2],
-                                   arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])(
-            std::thread::hardware_concurrency());
-
-        return 0;
+        if constexpr(NumDimSpatial == 1)
+        {
+            auto f_ncw = [&](auto n, auto c, auto wi) {
+                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t X = arg.weight_.mDesc.GetLengths()[2];
+
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[2];
+
+                AccDataType v_acc = 0;
+
+                for(int x = 0; x < X; ++x)
+                {
+                    int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0];
+                    if(w_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int wo = w_tmp / arg.conv_strides_[0];
+                        if(wo >= 0 && wo < Wo)
+                        {
+                            for(int k = 0; k < K; ++k)
+                            {
+                                AccDataType v_out = 0;
+                                AccDataType v_wei = 0;
+
+                                arg.out_element_op_(
+                                    v_out,
+                                    ck::type_convert<AccDataType>(arg.output_(n, k, wo)));
+                                arg.wei_element_op_(
+                                    v_wei, ck::type_convert<AccDataType>(arg.weight_(k, c, x)));
+                                v_acc += v_out * v_wei;
+                            }
+                        }
+                    }
+                }
+
+                float v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_ncw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+        else if constexpr(NumDimSpatial == 2)
+        {
+            auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
+                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t Y = arg.weight_.mDesc.GetLengths()[2];
+                std::size_t X = arg.weight_.mDesc.GetLengths()[3];
+
+                std::size_t Ho = arg.output_.mDesc.GetLengths()[2];
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[3];
+
+                AccDataType v_acc = 0;
+
+                for(int y = 0; y < Y; ++y)
+                {
+                    int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0];
+                    if(h_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int ho = h_tmp / arg.conv_strides_[0];
+                        if(ho >= 0 && ho < Ho)
+                        {
+                            for(int x = 0; x < X; ++x)
+                            {
+                                int w_tmp =
+                                    wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1];
+                                if(w_tmp % arg.conv_strides_[1] == 0)
+                                {
+                                    int wo = w_tmp / arg.conv_strides_[1];
+                                    if(wo >= 0 && wo < Wo)
+                                    {
+                                        for(int k = 0; k < K; ++k)
+                                        {
+                                            AccDataType v_out = 0;
+                                            AccDataType v_wei = 0;
+
+                                            arg.out_element_op_(v_out,
+                                                                ck::type_convert<AccDataType>(
+                                                                    arg.output_(n, k, ho, wo)));
+                                            arg.wei_element_op_(v_wei,
+                                                                ck::type_convert<AccDataType>(
+                                                                    arg.weight_(k, c, y, x)));
+
+                                            v_acc += v_out * v_wei;
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                AccDataType v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_nchw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2],
+                                       arg.input_.mDesc.GetLengths()[3])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
+                std::size_t K = arg.weight_.mDesc.GetLengths()[0];
+                std::size_t Z = arg.weight_.mDesc.GetLengths()[2];
+                std::size_t Y = arg.weight_.mDesc.GetLengths()[3];
+                std::size_t X = arg.weight_.mDesc.GetLengths()[4];
+
+                std::size_t Do = arg.output_.mDesc.GetLengths()[2];
+                std::size_t Ho = arg.output_.mDesc.GetLengths()[3];
+                std::size_t Wo = arg.output_.mDesc.GetLengths()[4];
+
+                AccDataType v_acc = 0;
+
+                for(int z = 0; z < Z; ++z)
+                {
+                    int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0];
+                    if(d_tmp % arg.conv_strides_[0] == 0)
+                    {
+                        int do_ = d_tmp / arg.conv_strides_[0];
+                        if(do_ >= 0 && do_ < Do)
+                        {
+                            for(int y = 0; y < Y; ++y)
+                            {
+                                int h_tmp =
+                                    hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1];
+                                if(h_tmp % arg.conv_strides_[1] == 0)
+                                {
+                                    int ho = h_tmp / arg.conv_strides_[1];
+                                    if(ho >= 0 && ho < Ho)
+                                    {
+                                        for(int x = 0; x < X; ++x)
+                                        {
+                                            int w_tmp = wi + arg.in_left_pads_[2] -
+                                                        x * arg.conv_dilations_[2];
+                                            if(w_tmp % arg.conv_strides_[2] == 0)
+                                            {
+                                                int wo = w_tmp / arg.conv_strides_[2];
+                                                if(wo >= 0 && wo < Wo)
+                                                {
+                                                    for(int k = 0; k < K; ++k)
+                                                    {
+                                                        AccDataType v_out = 0;
+                                                        AccDataType v_wei = 0;
+                                                        arg.out_element_op_(
+                                                            v_out,
+                                                            ck::type_convert<AccDataType>(
+                                                                arg.output_(
+                                                                    n, k, do_, ho, wo)));
+                                                        arg.wei_element_op_(
+                                                            v_wei,
+                                                            ck::type_convert<AccDataType>(
+                                                                arg.weight_(k, c, z, y, x)));
+                                                        v_acc += v_out * v_wei;
+                                                    }
+                                                }
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                AccDataType v_in;
+                arg.in_element_op_(v_in, v_acc);
+                arg.input_(n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
+            };
+
+            make_ParallelTensorFunctor(f_ncdhw,
+                                       arg.input_.mDesc.GetLengths()[0],
+                                       arg.input_.mDesc.GetLengths()[1],
+                                       arg.input_.mDesc.GetLengths()[2],
+                                       arg.input_.mDesc.GetLengths()[3],
+                                       arg.input_.mDesc.GetLengths()[4])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
     }

     float Run(const device::BaseArgument* p_arg, int, hipStream_t, bool) override
@@ -143,9 +278,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
     bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

-    static auto MakeArgument(Tensor<InDataType>& in_n_c_hi_wi,
-                             const Tensor<WeiDataType>& wei_k_c_y_x,
-                             const Tensor<OutDataType>& out_n_k_ho_wo,
+    static auto MakeArgument(Tensor<InDataType>& input,
+                             const Tensor<WeiDataType>& weight,
+                             const Tensor<OutDataType>& output,
                              std::vector<ck::index_t> conv_filter_strides,
                              std::vector<ck::index_t> conv_filter_dilations,
                              std::vector<ck::index_t> input_left_pads,
@@ -154,9 +289,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
                              WeiElementwiseOperation wei_element_op,
                              OutElementwiseOperation out_element_op)
     {
-        return Argument{in_n_c_hi_wi,
-                        wei_k_c_y_x,
-                        out_n_k_ho_wo,
+        return Argument{input,
+                        weight,
+                        output,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...
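For orientation, the rewritten Run above inverts the forward convolution index map independently per spatial axis: the forward relation wi = wo * stride + x * dilation - pad is solved for wo, and a filter tap contributes only when the stride divides exactly and wo lands inside the output. A standalone sketch of that inversion check with hypothetical sizes:

    #include <cstdio>

    // Sketch of the backward-data index inversion used by ReferenceConvBwdData::Run.
    int main()
    {
        const int Wi = 8, Wo = 4, X = 3;             // hypothetical extents
        const int stride = 2, dilation = 1, pad = 1; // hypothetical conv parameters

        for(int wi = 0; wi < Wi; ++wi)
            for(int x = 0; x < X; ++x)
            {
                int w_tmp = wi + pad - x * dilation;
                if(w_tmp % stride == 0) // the stride must divide exactly
                {
                    int wo = w_tmp / stride;
                    if(wo >= 0 && wo < Wo) // and the output index must be in range
                        std::printf("in[%d] += out[%d] * wei[%d]\n", wi, wo, x);
                }
            }
        return 0;
    }

The same divisibility-and-range check is applied to d, h, and w in turn, which is why the 3-D branch nests one level deeper per spatial dimension.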
@@ -14,9 +14,9 @@ namespace host {
 //
 // @brief Reference implementation for forward convolution.
 //
-// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout.
-//            Weights tensor supports KCYX data layout. Output tensor supports
-//            NKHoWo data layout.
+// @paragraph Supports both NCHW and NHWC formats (and their respective
+//            counterparts for the weight and output tensors), as long as the
+//            tensor descriptor lengths are given in NCHW order.
 //
 // @tparam InDataType  Input tensor data type.
 // @tparam WeiDataType Weights tensor data type.
@@ -100,9 +100,9 @@ struct ReferenceConvFwd : public device::BaseOperator
                         float v_wei;

                         arg.in_element_op_(v_in,
-                                           static_cast<const float>(arg.input_(n, c, wi)));
+                                           ck::type_convert<float>(arg.input_(n, c, wi)));
                         arg.wei_element_op_(v_wei,
-                                            static_cast<const float>(arg.weight_(k, c, x)));
+                                            ck::type_convert<float>(arg.weight_(k, c, x)));

                         v_acc += v_in * v_wei;
                     }
@@ -112,7 +112,7 @@ struct ReferenceConvFwd : public device::BaseOperator
                 float v_out;

                 arg.out_element_op_(v_out, v_acc);
-                arg.output_(n, k, wo) = v_out;
+                arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
             };

             make_ParallelTensorFunctor(f_ncw,
@@ -169,6 +169,61 @@ struct ReferenceConvFwd : public device::BaseOperator

             return 0;
         }
+        else if constexpr(NumDimSpatial == 3)
+        {
+            auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) {
+                float v_acc = 0;
+
+                for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
+                {
+                    for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z)
+                    {
+                        int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] -
+                                 arg.in_left_pads_[0];
+                        for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y)
+                        {
+                            int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] -
+                                     arg.in_left_pads_[1];
+                            for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x)
+                            {
+                                int wi = wo * arg.conv_strides_[2] +
+                                         x * arg.conv_dilations_[2] - arg.in_left_pads_[2];
+                                if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] &&
+                                   hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] &&
+                                   wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4])
+                                {
+                                    float v_in;
+                                    float v_wei;
+
+                                    arg.in_element_op_(
+                                        v_in,
+                                        ck::type_convert<float>(arg.input_(n, c, di, hi, wi)));
+                                    arg.wei_element_op_(
+                                        v_wei,
+                                        ck::type_convert<float>(arg.weight_(k, c, z, y, x)));
+
+                                    v_acc += v_in * v_wei;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                float v_out;
+
+                arg.out_element_op_(v_out, v_acc);
+                arg.output_(n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
+            };
+
+            make_ParallelTensorFunctor(f_nchw,
+                                       arg.output_.mDesc.GetLengths()[0],
+                                       arg.output_.mDesc.GetLengths()[1],
+                                       arg.output_.mDesc.GetLengths()[2],
+                                       arg.output_.mDesc.GetLengths()[3],
+                                       arg.output_.mDesc.GetLengths()[4])(
+                std::thread::hardware_concurrency());
+
+            return 0;
+        }
     }

     float Run(const device::BaseArgument* p_arg, int, hipStream_t, bool) override
...
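The added NumDimSpatial == 3 branch is the forward counterpart of the backward-data mapping: each input coordinate is computed directly as output coordinate * stride + tap * dilation - left pad, and a single bounds check decides whether the tap reads data or implicit padding. A compact sketch with hypothetical parameters:

    #include <cstdio>

    // Sketch of the forward 3-D index map used by the f_nchw lambda above.
    int main()
    {
        const int stride[3]   = {1, 2, 2}; // hypothetical (d, h, w) strides
        const int dilation[3] = {1, 1, 1};
        const int pad[3]      = {0, 1, 1};
        const int in_len[3]   = {4, 8, 8}; // Di, Hi, Wi

        const int out_coord[3] = {2, 3, 0}; // one (do, ho, wo) point
        const int tap[3]       = {1, 0, 2}; // one (z, y, x) filter tap

        bool inside = true;
        int  in_coord[3];
        for(int d = 0; d < 3; ++d)
        {
            in_coord[d] = out_coord[d] * stride[d] + tap[d] * dilation[d] - pad[d];
            inside      = inside && in_coord[d] >= 0 && in_coord[d] < in_len[d];
        }
        std::printf("(di, hi, wi) = (%d, %d, %d): %s\n",
                    in_coord[0], in_coord[1], in_coord[2],
                    inside ? "contributes to v_acc" : "falls in padding, skipped");
        return 0;
    }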
@@ -6,23 +6,36 @@
 #include "device_reduce_instance_blockwise_f32_f32_f32.hpp"
 #include "device_reduce_instance_blockwise_f32_f64_f32.hpp"
 #include "device_reduce_instance_blockwise_f64_f64_f64.hpp"
+#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
+#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
+#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
 #include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
 #include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
 #include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
+#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
+#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
+#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
+#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
 #include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
+#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
 #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
 #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
 #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
 #include "device_reduce_instance_threadwise_f32_f64_f32.hpp"
 #include "device_reduce_instance_threadwise_f64_f64_f64.hpp"
+#include "device_reduce_instance_threadwise_i8_i8_i8.hpp"
+#include "device_reduce_instance_threadwise_i8_i32_i8.hpp"
+#include "device_reduce_instance_threadwise_b16_f32_b16.hpp"
 #endif
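The new includes register int8 and bfloat16 reduction instances. Going by the column comments in the instance files below, a name like b16_f32_b16 reads as InDataType = bhalf_t, AccDataType = float, OutDataType = bhalf_t: values are widened once on load, accumulated in f32, and narrowed once on store. A sketch of that convert-accumulate-convert pattern; bhalf_t is modeled here as a raw 16-bit pattern purely for illustration (the real type lives in CK's headers):

    #include <cstdio>
    #include <cstring>

    using bhalf_sketch = unsigned short; // stand-in for ck::bhalf_t (assumption)

    float to_float(bhalf_sketch x)
    {
        unsigned u = unsigned(x) << 16; // bf16 occupies the top half of an f32
        float f;
        std::memcpy(&f, &u, 4);
        return f;
    }

    bhalf_sketch to_bhalf(float f)
    {
        unsigned u;
        std::memcpy(&u, &f, 4);
        return bhalf_sketch(u >> 16); // truncating convert, for illustration only
    }

    int main()
    {
        const bhalf_sketch in[4] = {to_bhalf(1.f), to_bhalf(2.f), to_bhalf(3.f), to_bhalf(4.f)};

        float acc = 0; // accumulate in f32 so bf16 rounding is not paid per element
        for(int i = 0; i < 4; ++i)
            acc += to_float(in[i]);

        bhalf_sketch out = to_bhalf(acc); // narrow once, at the end
        std::printf("sum = %f\n", to_float(out));
        return 0;
    }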
@@ -17,7 +17,6 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     ReductionConfiguration_2<0, 2, 2, 2, 1>,
     ReductionConfiguration_2<0, 1, 1, 2, 1>,
     ReductionConfiguration_2<1, 2, 1, 1, 2>,
-    ReductionConfiguration_2<1, 2, 2, 1, 2>,
     ReductionConfiguration_2<0, 1, 1, 3, 1>,
     ReductionConfiguration_2<1, 1, 1, 1, 3>
     // clang-format on
@@ -48,7 +47,7 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     >;
 #endif

-template <typename AccDataType, ReduceTensorOp_t ReduceOpId>
+template <typename AccDataType, ReduceTensorOp ReduceOpId>
 using deviceReduceBlockWisePtrType = DeviceReducePtr<
     typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
     typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
@@ -57,10 +56,10 @@ template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           int Rank,
-          typename ReduceDims,
-          ReduceTensorOp_t ReduceOpId,
-          NanPropagation_t NanOpt,
-          ReduceTensorIndices_t IndicesOpt>
+          int NumReduceDim,
+          ReduceTensorOp ReduceOpId,
+          NanPropagation NanOpt,
+          ReduceTensorIndices IndicesOpt>
 void add_device_reduce_instance_blockwise(
     std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
 {
@@ -72,11 +71,11 @@ void add_device_reduce_instance_blockwise(
         AccElementwiseOperation;

     constexpr bool Indexable =
-        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
-         ReduceOpId == ReduceTensorOp_t::AMAX);
-    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
+        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+         ReduceOpId == ReduceTensorOp::AMAX);
+    constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);

-    constexpr bool PropagateNan = (NanOpt == NanPropagation_t::NOT_PROPAGATE_NAN) ? false : true;
+    constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;

     static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
         using cfg1 =
@@ -91,7 +90,7 @@ void add_device_reduce_instance_blockwise(
             AccDataType,
             OutDataType,
             Rank,
-            ReduceDims,
+            NumReduceDim,
             ReduceOperation,
             InElementwiseOperation,
             AccElementwiseOperation,
@@ -112,34 +111,36 @@ void add_device_reduce_instance_blockwise(
     });
 };

-#define ADD_BLOCKWISE_INST_BY_TYPE(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    template void add_device_reduce_instance_blockwise<inT,                                     \
-                                                       compT,                                   \
-                                                       outT,                                    \
-                                                       Rank,                                    \
-                                                       Sequence<__VA_ARGS__>,                   \
-                                                       ReduceOpId,                              \
-                                                       NanOpt,                                  \
-                                                       IndicesOpt>(                             \
+#define ADD_BLOCKWISE_INST_BY_TYPE(                                           \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)     \
+    template void add_device_reduce_instance_blockwise<inT,                   \
+                                                       compT,                 \
+                                                       outT,                  \
+                                                       Rank,                  \
+                                                       NumReduceDim,          \
+                                                       ReduceOpId,            \
+                                                       NanOpt,                \
+                                                       IndicesOpt>(           \
         std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)

-#define ADD_BLOCKWISE_INST_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    ADD_BLOCKWISE_INST_BY_TYPE(inT,                                                           \
-                               compT,                                                         \
-                               outT,                                                          \
-                               static_cast<ReduceTensorOp_t>(ReduceOpId),                     \
-                               static_cast<NanPropagation_t>(NanOpt),                         \
-                               static_cast<ReduceTensorIndices_t>(IndicesOpt),                \
-                               Rank,                                                          \
-                               __VA_ARGS__)
+#define ADD_BLOCKWISE_INST_BY_ID(                                            \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)    \
+    ADD_BLOCKWISE_INST_BY_TYPE(inT,                                          \
+                               compT,                                        \
+                               outT,                                         \
+                               static_cast<ReduceTensorOp>(ReduceOpId),      \
+                               static_cast<NanPropagation>(NanOpt),          \
+                               static_cast<ReduceTensorIndices>(IndicesOpt), \
+                               Rank,                                         \
+                               NumReduceDim)

 #define ADD_BLOCKWISE_INST_REF_BY_TYPE(                                                  \
-    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...)                         \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)                \
     extern template void add_device_reduce_instance_blockwise<inT,                       \
                                                               compT,                     \
                                                               outT,                      \
                                                               Rank,                      \
-                                                              Sequence<__VA_ARGS__>,     \
+                                                              NumReduceDim,              \
                                                               ReduceOpId,                \
                                                               NanOpt,                    \
                                                               IndicesOpt>(               \
@@ -149,15 +150,16 @@ void add_device_reduce_instance_blockwise(
                  AccElementwiseOperation>> &                                             \
         device_op_instances)

-#define ADD_BLOCKWISE_INST_REF_BY_ID(inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, ...) \
-    ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                                           \
-                                   compT,                                                         \
-                                   outT,                                                          \
-                                   static_cast<ReduceTensorOp_t>(ReduceOpId),                     \
-                                   static_cast<NanPropagation_t>(NanOpt),                         \
-                                   static_cast<ReduceTensorIndices_t>(IndicesOpt),                \
-                                   Rank,                                                          \
-                                   __VA_ARGS__)
+#define ADD_BLOCKWISE_INST_REF_BY_ID(                                            \
+    inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)        \
+    ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,                                          \
+                                   compT,                                        \
+                                   outT,                                         \
+                                   static_cast<ReduceTensorOp>(ReduceOpId),      \
+                                   static_cast<NanPropagation>(NanOpt),          \
+                                   static_cast<ReduceTensorIndices>(IndicesOpt), \
+                                   Rank,                                         \
+                                   NumReduceDim)

 } // namespace device_reduce_instance
 } // namespace device
...
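Reading the rewritten macros together: the *_BY_ID wrappers take plain integer IDs and static_cast them to the renamed enums, and the variadic Sequence<__VA_ARGS__> of reduced dimensions is replaced by a single NumReduceDim count. Going by the macro bodies above and the ID comments in the instance files (0 is ADD), an ADD_BLOCKWISE_INST_BY_ID invocation with the arguments of the first bhalf_t line below (its REF_BY_ID form differs only in the extern keyword and the spelled-out pointer type) should expand to roughly the following; this is a sketch of the expansion, not code from the diff:

    // Approximate expansion of
    //   ADD_BLOCKWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3)
    template void add_device_reduce_instance_blockwise<
        bhalf_t,                              // InDataType
        float,                                // AccDataType
        bhalf_t,                              // OutDataType
        4,                                    // Rank
        3,                                    // NumReduceDim (was Sequence<0, 1, 2>)
        static_cast<ReduceTensorOp>(0),       // ADD, per the instance-file comments
        static_cast<NanPropagation>(0),       // NOT_PROPAGATE_NAN
        static_cast<ReduceTensorIndices>(0)>( // NO_INDICES
        std::vector<deviceReduceBlockWisePtrType<float, static_cast<ReduceTensorOp>(0)>>&
            device_op_instances);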
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
@@ -11,25 +11,31 @@ namespace device {
 namespace device_reduce_instance {

 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0, 1, 2); // for MIN
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0, 1, 2); // for MAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0, 1, 2); // for AMAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0, 1, 2); // for MIN
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0, 1, 2); // for MAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0, 1, 2); // for AMAX
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);       //
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on

 } // namespace device_reduce_instance
...
@@ -11,16 +11,19 @@ namespace device {
 namespace device_reduce_instance {

 // clang-format off
-// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | ReduceDims
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0, 1, 2); // for ADD
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 0);
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0, 1, 2); // for AVG
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0, 1, 2); // for NORM2
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 0);       //
-ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);       //
+// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
+ADD_BLOCKWISE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on

 } // namespace device_reduce_instance
...