"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "b8a5c570bf7bd086843c26fb334363e1b1f442c4"
Unverified Commit 0345963e authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Add MNK padding, M = 0 support into grouped_gemm (#539)



* add mnk padding, support m=0

* clean code

* clean code
Co-authored-by: default avatarRostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
parent 11151175
Pipeline #663 failed with stages
in 0 seconds
...@@ -373,12 +373,20 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -373,12 +373,20 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
gemm_desc_kernel_arg_.reserve(group_count_); gemm_desc_kernel_arg_.reserve(group_count_);
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
const index_t M = gemm_descs[i].M_; const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M == 0)
{
skipped_group_count_++;
continue;
}
const index_t StrideA = gemm_descs[i].stride_A_; const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_; const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].stride_C_; const index_t StrideC = gemm_descs[i].stride_C_;
...@@ -470,6 +478,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -470,6 +478,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
// private: // private:
index_t group_count_; index_t group_count_;
index_t skipped_group_count_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_; CDEElementwiseOperation c_element_op_;
...@@ -581,7 +591,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -581,7 +591,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{ {
return false; return false;
} }
......
...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< ...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
// clang-format on
>;
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col, std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row, Row,
...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances( ...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< ...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
// clang-format on
>;
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col, std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col, Col,
...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances( ...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< ...@@ -56,6 +56,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
// clang-format on
>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row, std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row, Row,
...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( ...@@ -71,6 +84,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -53,6 +53,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< ...@@ -53,6 +53,19 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format on // clang-format on
>; >;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
// clang-format on
>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row, std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col, Col,
...@@ -68,6 +81,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( ...@@ -68,6 +81,8 @@ void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
} }
} // namespace instance } // namespace instance
......
...@@ -31,6 +31,9 @@ std::size_t HostTensorDescriptor::GetElementSpaceSize() const ...@@ -31,6 +31,9 @@ std::size_t HostTensorDescriptor::GetElementSpaceSize() const
std::size_t space = 1; std::size_t space = 1;
for(std::size_t i = 0; i < mLens.size(); ++i) for(std::size_t i = 0; i < mLens.size(); ++i)
{ {
if(mLens[i] == 0)
continue;
space += (mLens[i] - 1) * mStrides[i]; space += (mLens[i] - 1) * mStrides[i];
} }
return space; return space;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment