"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "643ebd4f3e250ef1ade1f21ff82a4ca30d8a30c7"
Commit c20aabc3 authored by Jing Zhang

finished ckprofiler

parent 857010cc
 # device_grouped_gemm_instance
 set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE
     device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
-    #device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
-    #device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
-    #device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
+    device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
+    device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
+    device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
 )
 add_library(device_grouped_gemm_instance SHARED ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE})
@@ -23,12 +23,12 @@ using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemm
 void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
     std::vector<DeviceGroupedGemmNoOpPtr>&);
-// void
-// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-// void
-// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
-// void
-// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<DeviceGroupedGemmNoOpPtr>&);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+    std::vector<DeviceGroupedGemmNoOpPtr>&);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+    std::vector<DeviceGroupedGemmNoOpPtr>&);
 } // namespace device_grouped_gemm_instance
 } // namespace device
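
Note: the declarations uncommented above are instance factories. Each one appends every concrete grouped-GEMM kernel built for one A/B/C layout combination to a caller-owned vector of DeviceGroupedGemmNoOpPtr, which is how ckProfiler enumerates candidate kernels. Below is a minimal, self-contained sketch of that registry pattern; BaseOp, OpPtr, and add_mk_nk_mn_instances are hypothetical stand-ins, and only the shape of the factory mirrors this diff (GetTypeString is modeled on the identically named CK method).

    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for CK's operator base class.
    struct BaseOp
    {
        virtual ~BaseOp() = default;
        virtual std::string GetTypeString() const = 0;
    };
    using OpPtr = std::unique_ptr<BaseOp>; // stand-in for DeviceGroupedGemmNoOpPtr

    struct GroupedGemmMkNkMn final : BaseOp
    {
        std::string GetTypeString() const override { return "grouped_gemm_xdl f16 mk_nk_mn"; }
    };

    // Mirrors add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances: the
    // factory appends all concrete instances for one layout combination.
    void add_mk_nk_mn_instances(std::vector<OpPtr>& ops)
    {
        ops.push_back(std::make_unique<GroupedGemmMkNkMn>());
    }

    int main()
    {
        std::vector<OpPtr> ops;
        add_mk_nk_mn_instances(ops); // the profiler picks the factory by layout
        for(const auto& op : ops)
            std::cout << op->GetTypeString() << '\n';
    }

In the real header there is one such factory per layout, so the compile-time dispatch in profile_grouped_gemm_impl (next hunk) only has to call the matching function.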
@@ -167,65 +167,27 @@ void profile_grouped_gemm_impl(int do_verification,
         ck::tensor_operation::device::device_grouped_gemm_instance::
             add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
     }
-#if 0
     else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
     {
-        if(KBatch > 1)
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-        }
-        else
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-        }
+        ck::tensor_operation::device::device_grouped_gemm_instance::
+            add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
     }
     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                       is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
     {
-        if(KBatch > 1)
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-        }
-        else
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-        }
+        ck::tensor_operation::device::device_grouped_gemm_instance::
+            add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
     }
     else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
                       is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                       is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
     {
-        if(KBatch > 1)
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-        }
-        else
-        {
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-            ck::tensor_operation::device::device_grouped_gemm_instance::
-                add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
-        }
+        ck::tensor_operation::device::device_grouped_gemm_instance::
+            add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
     }
-#endif
     }

     if(gemm_ptrs.size() <= 0)
@@ -238,7 +200,6 @@ void profile_grouped_gemm_impl(int do_verification,
     float best_tflops     = 0;
     float best_gb_per_sec = 0;

-#if 1
     // profile device GEMM instances
     for(auto& gemm_ptr : gemm_ptrs)
     {
@@ -330,11 +291,10 @@ void profile_grouped_gemm_impl(int do_verification,
                 std::cout << "does not support this GEMM problem" << std::endl;
         }
     }
-#endif

     std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
               << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
 }

 } // namespace profiler
 } // namespace ck
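
Note: with the #if 0 / #endif and #if 1 / #endif guards deleted above, all four layout branches and the profiling loop are now compiled unconditionally. The loop follows the pattern visible in the surviving context lines: try each instance, print "does not support this GEMM problem" when an instance rejects the problem, and track the best average time, TFlops, and instance name. A self-contained sketch of that selection loop with toy stand-ins (FakeInstance and its fields are hypothetical; only the control flow, the skip-if-unsupported message, the flop / (ms * 1e9) TFlops conversion, and the final "Best Perf" line mirror the profiler):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Toy stand-in for a CK device instance: a name, a support predicate,
    // and a "run" that returns an averaged kernel time in ms. All hypothetical.
    struct FakeInstance
    {
        std::string name;
        std::function<bool()> is_supported;
        std::function<float()> run_ms;
    };

    int main()
    {
        std::vector<FakeInstance> instances = {
            {"xdl_256x128", [] { return true; }, [] { return 0.42f; }},
            {"xdl_128x128", [] { return false; }, [] { return 0.00f; }}, // skipped
            {"xdl_64x64", [] { return true; }, [] { return 0.58f; }},
        };

        const double flop = 2.0 * 1024 * 1024 * 1024; // 2*M*N*K for one toy GEMM

        float best_ave_time = 0;
        float best_tflops   = 0;
        std::string best_name;

        for(auto& inst : instances)
        {
            if(!inst.is_supported())
            {
                std::cout << inst.name << " does not support this GEMM problem\n";
                continue;
            }
            const float ave_time = inst.run_ms();
            // flop / (ms * 1e9) yields TFlops, since ms * 1e9 = s * 1e12.
            const float tflops = static_cast<float>(flop / (ave_time * 1.0e9));
            if(tflops > best_tflops)
            {
                best_tflops   = tflops;
                best_ave_time = ave_time;
                best_name     = inst.name;
            }
        }
        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops
                  << " TFlops, " << best_name << std::endl;
    }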
@@ -93,192 +93,64 @@ int profile_grouped_gemm(int argc, char* argv[])
             StrideBs,
             StrideCs);
     }
-#if 0
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
-        ck::profiler::profile_gemm_impl<ck::half_t,
-                                        ck::half_t,
-                                        ck::half_t,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+                                                ck::half_t,
+                                                ck::half_t,
+                                                ck::tensor_layout::gemm::RowMajor,
+                                                ck::tensor_layout::gemm::ColumnMajor,
+                                                ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                                                   init_method,
+                                                                                   do_log,
+                                                                                   nrepeat,
+                                                                                   Ms,
+                                                                                   Ns,
+                                                                                   Ks,
+                                                                                   StrideAs,
+                                                                                   StrideBs,
+                                                                                   StrideCs);
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
     {
-        ck::profiler::profile_gemm_impl<ck::half_t,
-                                        ck::half_t,
-                                        ck::half_t,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+                                                ck::half_t,
+                                                ck::half_t,
+                                                ck::tensor_layout::gemm::ColumnMajor,
+                                                ck::tensor_layout::gemm::RowMajor,
+                                                ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                                                   init_method,
+                                                                                   do_log,
+                                                                                   nrepeat,
+                                                                                   Ms,
+                                                                                   Ns,
+                                                                                   Ks,
+                                                                                   StrideAs,
+                                                                                   StrideBs,
+                                                                                   StrideCs);
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
     {
-        ck::profiler::profile_gemm_impl<ck::half_t,
-                                        ck::half_t,
-                                        ck::half_t,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
+        ck::profiler::profile_grouped_gemm_impl<ck::half_t,
+                                                ck::half_t,
+                                                ck::half_t,
+                                                ck::tensor_layout::gemm::ColumnMajor,
+                                                ck::tensor_layout::gemm::ColumnMajor,
+                                                ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                                                   init_method,
+                                                                                   do_log,
+                                                                                   nrepeat,
+                                                                                   Ms,
+                                                                                   Ns,
+                                                                                   Ks,
+                                                                                   StrideAs,
+                                                                                   StrideBs,
+                                                                                   StrideCs);
-    }
-    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
-    {
-        ck::profiler::profile_gemm_impl<float,
-                                        float,
-                                        float,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
-    }
-    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
-    {
-        ck::profiler::profile_gemm_impl<float,
-                                        float,
-                                        float,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? K : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
-    }
-    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
-    {
-        ck::profiler::profile_gemm_impl<float,
-                                        float,
-                                        float,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? N : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
-    }
-    else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
-    {
-        ck::profiler::profile_gemm_impl<float,
-                                        float,
-                                        float,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
-    }
-    else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_NK_MN)
-    {
-        ck::profiler::profile_gemm_impl<int8_t,
-                                        int8_t,
-                                        int8_t,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
-    }
-    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
-    {
-        ck::profiler::profile_gemm_impl<ck::bhalf_t,
-                                        ck::bhalf_t,
-                                        ck::bhalf_t,
-                                        ck::tensor_layout::gemm::RowMajor,
-                                        ck::tensor_layout::gemm::ColumnMajor,
-                                        ck::tensor_layout::gemm::RowMajor>(
-            do_verification,
-            init_method,
-            do_log,
-            nrepeat,
-            M,
-            N,
-            K,
-            (StrideA < 0) ? M : StrideA,
-            (StrideB < 0) ? K : StrideB,
-            (StrideC < 0) ? N : StrideC,
-            KBatch);
     }
     else
     {
         throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
     }
-#endif

     return 1;
 }
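
Note: compared with the single-GEMM profile_gemm_impl calls deleted above, the grouped entry point takes one value per group: the scalars M, N, K and StrideA/B/C (with their (Stride* < 0) defaults) become the vectors Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, and the KBatch split-K argument is dropped. A standalone sketch of filling those per-group vectors for the MK_NK_MN layout, reusing the default leading dimensions the removed code computed (K for A, K for B, N for C); the loop and example sizes are illustrative, not taken from the profiler:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main()
    {
        // One entry per GEMM in the group, mirroring Ms/Ns/Ks/StrideAs/StrideBs/StrideCs.
        std::vector<int> Ms{256, 512}, Ns{128, 256}, Ks{64, 64};
        std::vector<int> StrideAs, StrideBs, StrideCs;

        for(std::size_t i = 0; i < Ms.size(); ++i)
        {
            StrideAs.push_back(Ks[i]); // A[m, k], row-major    -> leading dim K
            StrideBs.push_back(Ks[i]); // B[n, k], column-major -> leading dim K
            StrideCs.push_back(Ns[i]); // C[m, n], row-major    -> leading dim N
        }

        // Total work across the group, the basis of the TFlops the profiler prints.
        std::int64_t flop = 0;
        for(std::size_t i = 0; i < Ms.size(); ++i)
            flop += std::int64_t{2} * Ms[i] * Ns[i] * Ks[i];

        std::cout << "groups: " << Ms.size() << ", total flop: " << flop << std::endl;
    }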