Commit c7bf4232 authored by letaoqin's avatar letaoqin
Browse files

device transfer elementwiseop to gridwise

parent f8aef548
......@@ -17,7 +17,8 @@ using S = ck::Sequence<Is...>;
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::Relu;
;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
......
......@@ -103,6 +103,9 @@ template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_K0_M0_M1_K1,
typename BGridDesc_K0_N0_N1_K1,
typename DsGridDesc_M0_M10_M11_N0_N10_N11,
......@@ -120,6 +123,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
......@@ -160,6 +166,9 @@ __global__ void
p_ds_grid_grp,
p_e_grid + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_k0_m0_m1_k1,
b_grid_desc_k0_n0_n1_k1,
ds_grid_desc_m0_m10_m11_n0_n10_n11,
......@@ -172,6 +181,9 @@ __global__ void
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = batch_count;
ignore = a_grid_desc_k0_m0_m1_k1;
ignore = b_grid_desc_k0_n0_n1_k1;
......@@ -212,10 +224,10 @@ template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsDataType,
typename CLayout,
typename ELayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t BlockSize,
......@@ -250,14 +262,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ALayout,
BLayout,
DsLayout,
CLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDElementwiseOperation>
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
......@@ -338,13 +350,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <typename CLay>
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(e_g_n_k_wos_lengths,
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc =
......@@ -373,7 +385,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<CLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm =
......@@ -382,6 +394,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
AccDataType,
DsLayout,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
......@@ -447,7 +462,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op)
const CDEElementwiseOperation& cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a)},
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_ds_grid_{},
......@@ -466,7 +481,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<CLayout>(e_g_n_k_wos_lengths,
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
......@@ -476,7 +491,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cd_element_op_{cd_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
......@@ -570,7 +585,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDElementwiseOperation cd_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
......@@ -621,6 +636,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_K0_M0_M1_K1,
DeviceOp::BGridDesc_K0_N0_N1_K1,
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
......@@ -639,6 +657,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
......@@ -796,11 +817,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
// check vector access of E
if constexpr(is_same_v<CLayout, ctc::G_NW_K> || is_same_v<CLayout, ctc::G_NHW_K> ||
is_same_v<CLayout, ctc::G_NDHW_K> || is_same_v<CLayout, ctc::GNWK> ||
is_same_v<CLayout, ctc::GNHWK> || is_same_v<CLayout, ctc::GNDHWK> ||
is_same_v<CLayout, ctc::NWGK> || is_same_v<CLayout, ctc::NHWGK> ||
is_same_v<CLayout, ctc::NDHWGK>)
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
......@@ -842,7 +863,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op)
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_a,
p_b,
......@@ -862,7 +883,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads,
a_element_op,
b_element_op,
cd_element_op};
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -886,7 +907,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDElementwiseOperation& cd_element_op) override
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
......@@ -906,7 +927,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads,
a_element_op,
b_element_op,
cd_element_op);
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
......@@ -22,6 +22,9 @@ template <index_t BlockSize,
typename FloatAcc,
typename DsDataType,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
......@@ -247,6 +250,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
DsGridPointer p_ds_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11,
......@@ -257,6 +263,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
{
ignore = p_ds_grid;
ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment