Commit 734df790 authored by root's avatar root
Browse files

Add MK-KN FP16 instances.

parent f9f2cdf9
......@@ -956,7 +956,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer << ", "
<< LoopSched
<< ">";
// clang-format on
......
......@@ -30,6 +30,19 @@ void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregu
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
typename ELayout,
......@@ -74,6 +87,8 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
......
add_instance_library(device_grouped_gemm_multiple_d_instance
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
)
......@@ -91,28 +91,28 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
#ifdef CK_ENABLE_FP16
// if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
// {
// ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
// ck::half_t,
// ck::half_t,
// float,
// ck::tensor_layout::gemm::RowMajor,
// ck::tensor_layout::gemm::RowMajor,
// ck::tensor_layout::gemm::RowMajor>(do_verification,
// init_method,
// do_log,
// time_kernel,
// Ms,
// Ns,
// Ks,
// StrideAs,
// StrideBs,
// StrideCs,
// kbatch);
// }
// else
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
ck::half_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment