Commit c7bf4232 authored by letaoqin's avatar letaoqin
Browse files

device transfer elementwiseop to gridwise

parent f8aef548
...@@ -17,7 +17,8 @@ using S = ck::Sequence<Is...>; ...@@ -17,7 +17,8 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough; using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::Relu;
;
static constexpr auto ConvSpec = static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
......
...@@ -103,6 +103,9 @@ template <typename GridwiseGemm, ...@@ -103,6 +103,9 @@ template <typename GridwiseGemm,
typename ABDataType, typename ABDataType,
typename DsPointer, typename DsPointer,
typename EDataType, typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_K0_M0_M1_K1, typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1, typename BGridDesc_K0_N0_N1_K1,
typename DsGridDesc_M0_M10_M11_N0_N10_N11, typename DsGridDesc_M0_M10_M11_N0_N10_N11,
...@@ -120,6 +123,9 @@ __global__ void ...@@ -120,6 +123,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid, const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count, const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
...@@ -160,6 +166,9 @@ __global__ void ...@@ -160,6 +166,9 @@ __global__ void
p_ds_grid_grp, p_ds_grid_grp,
p_e_grid + c_batch_offset, p_e_grid + c_batch_offset,
p_shared, p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m0_m1_k1, a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1, b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11, ds_grid_desc_m0_m10_m11_n0_n10_n11,
...@@ -172,6 +181,9 @@ __global__ void ...@@ -172,6 +181,9 @@ __global__ void
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = batch_count; ignore = batch_count;
ignore = a_grid_desc_k0_m0_m1_k1; ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1; ignore = b_grid_desc_k0_n0_n1_k1;
...@@ -212,10 +224,10 @@ template <index_t NDimSpatial, ...@@ -212,10 +224,10 @@ template <index_t NDimSpatial,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename DsDataType, typename DsDataType,
typename CLayout, typename ELayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDElementwiseOperation, typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization, ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec, GemmSpecialization GemmSpec,
index_t BlockSize, index_t BlockSize,
...@@ -250,14 +262,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -250,14 +262,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ALayout, ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
CLayout, ELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, DsDataType,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDElementwiseOperation> CDEElementwiseOperation>
{ {
using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
...@@ -338,13 +350,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -338,13 +350,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
} }
template <typename CLay> template <typename ELay>
static auto static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides) const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides); e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
...@@ -373,7 +385,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -373,7 +385,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<CLayout>({}, {}))>; using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = using GridwiseGemm =
...@@ -382,6 +394,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -382,6 +394,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
AccDataType, AccDataType,
DsLayout, DsLayout,
EDataType, EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1, AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1, BGridDesc_BK0_N_BK1,
...@@ -447,7 +462,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -447,7 +462,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads, const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op) const CDEElementwiseOperation& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)}, : p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)}, p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{}, p_ds_grid_{},
...@@ -466,7 +481,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -466,7 +481,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads)}, input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>( b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<CLayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)}, e_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
...@@ -476,7 +491,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -476,7 +491,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_{}, compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cd_element_op_{cd_element_op}, cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
...@@ -570,7 +585,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -570,7 +585,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDElementwiseOperation cd_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_; std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
...@@ -621,6 +636,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -621,6 +636,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_K0_M0_M1_K1, DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1, DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
...@@ -639,6 +657,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -639,6 +657,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
...@@ -796,11 +817,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -796,11 +817,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} }
// check vector access of E // check vector access of E
if constexpr(is_same_v<CLayout, ctc::G_NW_K> || is_same_v<CLayout, ctc::G_NHW_K> || if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<CLayout, ctc::G_NDHW_K> || is_same_v<CLayout, ctc::GNWK> || is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<CLayout, ctc::GNHWK> || is_same_v<CLayout, ctc::GNDHWK> || is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<CLayout, ctc::NWGK> || is_same_v<CLayout, ctc::NHWGK> || is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<CLayout, ctc::NDHWGK>) is_same_v<ELayout, ctc::NDHWGK>)
{ {
const index_t K = arg.e_g_n_k_wos_lengths_[2]; const index_t K = arg.e_g_n_k_wos_lengths_[2];
...@@ -842,7 +863,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -842,7 +863,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads, const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op) const CDEElementwiseOperation& cde_element_op)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -862,7 +883,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -862,7 +883,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads, input_right_pads,
a_element_op, a_element_op,
b_element_op, b_element_op,
cd_element_op}; cde_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -886,7 +907,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -886,7 +907,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads, const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op) override const CDEElementwiseOperation& cde_element_op) override
{ {
return std::make_unique<Argument>(p_a, return std::make_unique<Argument>(p_a,
p_b, p_b,
...@@ -906,7 +927,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -906,7 +927,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads, input_right_pads,
a_element_op, a_element_op,
b_element_op, b_element_op,
cd_element_op); cde_element_op);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -22,6 +22,9 @@ template <index_t BlockSize, ...@@ -22,6 +22,9 @@ template <index_t BlockSize,
typename FloatAcc, typename FloatAcc,
typename DsDataType, typename DsDataType,
typename FloatC, typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1, typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1, typename BGridDesc_K0_N_K1,
...@@ -247,6 +250,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -247,6 +250,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
DsGridPointer p_ds_grid, DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11, const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
...@@ -257,6 +263,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn ...@@ -257,6 +263,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
{ {
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment