Commit 734df790 authored by root's avatar root
Browse files

Add MK-KN FP16 instances.

parent f9f2cdf9
......@@ -956,7 +956,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer << ", "
<< LoopSched
<< ">";
// clang-format on
......
......@@ -30,6 +30,19 @@ void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregu
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
......@@ -74,6 +87,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
......
add_instance_library(device_grouped_gemm_multiple_d_instance
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
)
......@@ -43,7 +43,7 @@ using device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_tile_in
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmMultipleDSplitKXdlCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
......
......@@ -91,28 +91,28 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
#ifdef CK_ENABLE_FP16
// if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
// ck::half_t,
// ck::half_t,
// float,
// ck::tensor_layout::gemm::RowMajor,
// ck::tensor_layout::gemm::RowMajor,
// ck::tensor_layout::gemm::RowMajor>(do_verification,
// init_method,
// do_log,
// time_kernel,
// Ms,
// Ns,
// Ks,
// StrideAs,
// StrideBs,
// StrideCs,
// kbatch);
// }
// else
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
ck::half_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment