Commit 4511f877 authored by Chao Liu
Browse files

refactor profiler

parent 519b6aaf
...@@ -4,97 +4,99 @@ ...@@ -4,97 +4,99 @@
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include "profile_convnd_fwd.hpp" bool profile_gemm(int, char*[]);
bool profile_gemm_splitk(int, char*[]);
int profile_gemm(int, char*[]); bool profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_2d(int, char*[]); bool profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu(int, char*[]); bool profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]); bool profile_gemm_reduce(int, char*[]);
int profile_gemm_reduce(int, char*[]); bool profile_batched_gemm(int, char*[]);
int profile_batched_gemm(int, char*[]); bool profile_grouped_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]); bool profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]); bool profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]); bool profile_convnd_fwd(int argc, char* argv[]);
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); bool profile_convnd_bwd_data(int, char*[], int);
int profile_convnd_bwd_data(int, char*[], int); bool profile_reduce(int, char*[]);
int profile_reduce(int, char*[]); bool profile_conv_bwd_weight(int, char*[]);
int profile_conv_bwd_weight(int, char*[]); bool profile_batched_gemm_reduce(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
bool pass = true;
if(strcmp(argv[1], "gemm") == 0) if(strcmp(argv[1], "gemm") == 0)
{ {
return profile_gemm(argc, argv); pass = profile_gemm(argc, argv);
}
if(strcmp(argv[1], "gemm_splitk") == 0)
{
pass = profile_gemm_splitk(argc, argv);
} }
else if(strcmp(argv[1], "gemm_bias_2d") == 0) else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{ {
return profile_gemm_bias_2d(argc, argv); pass = profile_gemm_bias_2d(argc, argv);
} }
else if(strcmp(argv[1], "gemm_bias_relu") == 0) else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{ {
return profile_gemm_bias_relu(argc, argv); pass = profile_gemm_bias_relu(argc, argv);
} }
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0) else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{ {
return profile_gemm_bias_relu_add(argc, argv); pass = profile_gemm_bias_relu_add(argc, argv);
} }
else if(strcmp(argv[1], "gemm_reduce") == 0) else if(strcmp(argv[1], "gemm_reduce") == 0)
{ {
return profile_gemm_reduce(argc, argv); pass = profile_gemm_reduce(argc, argv);
} }
else if(strcmp(argv[1], "batched_gemm") == 0) else if(strcmp(argv[1], "batched_gemm") == 0)
{ {
return profile_batched_gemm(argc, argv); pass = profile_batched_gemm(argc, argv);
} }
else if(strcmp(argv[1], "batched_gemm_reduce") == 0) else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{ {
return profile_batched_gemm_reduce(argc, argv); pass = profile_batched_gemm_reduce(argc, argv);
} }
else if(strcmp(argv[1], "grouped_gemm") == 0) else if(strcmp(argv[1], "grouped_gemm") == 0)
{ {
profile_grouped_gemm(argc, argv); pass = profile_grouped_gemm(argc, argv);
} }
else if(strcmp(argv[1], "conv_fwd") == 0) else if(strcmp(argv[1], "conv_fwd") == 0)
{ {
return ck::profiler::profile_convnd_fwd(argc, argv); pass = profile_convnd_fwd(argc, argv);
} }
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{ {
return profile_conv_fwd_bias_relu(argc, argv); pass = profile_conv_fwd_bias_relu(argc, argv);
} }
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0) else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{ {
return profile_conv_fwd_bias_relu_add(argc, argv); pass = profile_conv_fwd_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
} }
else if(strcmp(argv[1], "conv1d_bwd_data") == 0) else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{ {
return profile_convnd_bwd_data(argc, argv, 1); pass = profile_convnd_bwd_data(argc, argv, 1);
} }
else if(strcmp(argv[1], "conv2d_bwd_data") == 0) else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{ {
return profile_convnd_bwd_data(argc, argv, 2); pass = profile_convnd_bwd_data(argc, argv, 2);
} }
else if(strcmp(argv[1], "conv3d_bwd_data") == 0) else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{ {
return profile_convnd_bwd_data(argc, argv, 3); pass = profile_convnd_bwd_data(argc, argv, 3);
} }
else if(strcmp(argv[1], "reduce") == 0) else if(strcmp(argv[1], "reduce") == 0)
{ {
return profile_reduce(argc, argv); pass = profile_reduce(argc, argv);
} }
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0) else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{ {
return profile_conv_bwd_weight(argc, argv); pass = profile_conv_bwd_weight(argc, argv);
} }
else else
{ {
// clang-format off // clang-format off
printf("arg1: tensor operation (gemm: GEMM\n" printf("arg1: tensor operation, gemm: GEMM\n"
" gemm_splitk: GEMMSplitK)\n"
" gemm_bias_2d: GEMM+Bias(2D)\n" " gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n" " gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
...@@ -103,13 +105,13 @@ int main(int argc, char* argv[]) ...@@ -103,13 +105,13 @@ int main(int argc, char* argv[])
" conv_fwd: ForwardConvolution\n" " conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
" conv1d_bwd_data: BackwardConvolution data 1 dim\n" " conv1d_bwd_data: BackwardConvolution data 1 dim\n"
" conv2d_bwd_data: BackwardConvolution data 2 dim\n" " conv2d_bwd_data: BackwardConvolution data 2 dim\n"
" conv3d_bwd_data: BackwardConvolution data 3 dim\n" " conv3d_bwd_data: BackwardConvolution data 3 dim\n"
" reduce: REDUCE\n" " reduce: Reduce\n"
" conv2d_bwd_weight: Backward Weight Convolution 2d\n"); " conv2d_bwd_weight: Backward Weight Convolution 2d\n");
// clang-format on // clang-format on
} }
return 0;
return pass ? 0 : 1;
} }
...@@ -52,7 +52,7 @@ add_subdirectory(space_filling_curve) ...@@ -52,7 +52,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util) add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd) add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_split_k) add_subdirectory(gemm_splitk)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm) add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_reduce)
......
...@@ -31,6 +31,7 @@ namespace ck { ...@@ -31,6 +31,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&); std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
...@@ -39,6 +40,7 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( ...@@ -39,6 +40,7 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&); std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&); std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance } // namespace device_gemm_instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -33,11 +33,6 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo ...@@ -33,11 +33,6 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
...@@ -63,8 +58,6 @@ int main() ...@@ -63,8 +58,6 @@ int main()
std::vector<DeviceGemmNoOpPtr> gemmPtrs; std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
...@@ -85,8 +78,6 @@ int main() ...@@ -85,8 +78,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
...@@ -107,8 +98,6 @@ int main() ...@@ -107,8 +98,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
...@@ -129,8 +118,6 @@ int main() ...@@ -129,8 +118,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
......
...@@ -31,16 +31,12 @@ namespace ck { ...@@ -31,16 +31,12 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace device_gemm_instance {
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
...@@ -64,8 +60,6 @@ int main() ...@@ -64,8 +60,6 @@ int main()
std::vector<DeviceGemmNoOpPtr> gemmPtrs; std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
...@@ -86,8 +80,6 @@ int main() ...@@ -86,8 +80,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
...@@ -108,8 +100,6 @@ int main() ...@@ -108,8 +100,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
...@@ -130,8 +120,6 @@ int main() ...@@ -130,8 +120,6 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
......
...@@ -31,14 +31,12 @@ namespace ck { ...@@ -31,14 +31,12 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance } // namespace device_gemm_instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
...@@ -57,7 +55,7 @@ int main() ...@@ -57,7 +55,7 @@ int main()
bool res = true; bool res = true;
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs) for(auto& gemmPtr : gemmPtrs)
{ {
...@@ -75,7 +73,7 @@ int main() ...@@ -75,7 +73,7 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs) for(auto& gemmPtr : gemmPtrs)
{ {
...@@ -93,7 +91,7 @@ int main() ...@@ -93,7 +91,7 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs) for(auto& gemmPtr : gemmPtrs)
{ {
...@@ -111,7 +109,7 @@ int main() ...@@ -111,7 +109,7 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs) for(auto& gemmPtr : gemmPtrs)
{ {
......
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
target_link_libraries(test_gemm_split_k PRIVATE host_tensor)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance)
add_test_executable(test_gemm_splitk_fp32 gemm_splitk_fp32.cpp)
target_link_libraries(test_gemm_splitk_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_splitk_fp32 PRIVATE device_gemm_splitk_instance)
...@@ -35,6 +35,11 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic ...@@ -35,6 +35,11 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance } // namespace device_gemm_instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
...@@ -57,7 +62,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -57,7 +62,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return true; return true;
} }
struct gemmArgs struct GemmArgs
{ {
GemmMatrixLayout layout; GemmMatrixLayout layout;
int M; int M;
...@@ -69,7 +74,7 @@ struct gemmArgs ...@@ -69,7 +74,7 @@ struct gemmArgs
int KBatch; int KBatch;
}; };
int test_gemm(const gemmArgs& args) int test_gemm(const GemmArgs& args)
{ {
bool a_row_major, b_row_major, c_row_major; bool a_row_major, b_row_major, c_row_major;
...@@ -213,7 +218,7 @@ int test_gemm(const gemmArgs& args) ...@@ -213,7 +218,7 @@ int test_gemm(const gemmArgs& args)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
std::vector<gemmArgs> test_cases; std::vector<GemmArgs> test_cases;
if(argc == 1) if(argc == 1)
{ {
test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}}; test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment