Commit 4511f877 authored by Chao Liu
Browse files

refactor profiler

parent 519b6aaf
......@@ -4,97 +4,99 @@
#include <cstdlib>
#include <cstring>
#include "profile_convnd_fwd.hpp"
int profile_gemm(int, char*[]);
int profile_gemm_bias_2d(int, char*[]);
int profile_gemm_bias_relu(int, char*[]);
int profile_gemm_bias_relu_add(int, char*[]);
int profile_gemm_reduce(int, char*[]);
int profile_batched_gemm(int, char*[]);
int profile_grouped_gemm(int, char*[]);
int profile_conv_fwd_bias_relu(int, char*[]);
int profile_conv_fwd_bias_relu_add(int, char*[]);
int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
int profile_convnd_bwd_data(int, char*[], int);
int profile_reduce(int, char*[]);
int profile_conv_bwd_weight(int, char*[]);
int profile_batched_gemm_reduce(int, char*[]);
bool profile_gemm(int, char*[]);
bool profile_gemm_splitk(int, char*[]);
bool profile_gemm_bias_2d(int, char*[]);
bool profile_gemm_bias_relu(int, char*[]);
bool profile_gemm_bias_relu_add(int, char*[]);
bool profile_gemm_reduce(int, char*[]);
bool profile_batched_gemm(int, char*[]);
bool profile_grouped_gemm(int, char*[]);
bool profile_conv_fwd_bias_relu(int, char*[]);
bool profile_conv_fwd_bias_relu_add(int, char*[]);
bool profile_convnd_fwd(int argc, char* argv[]);
bool profile_convnd_bwd_data(int, char*[], int);
bool profile_reduce(int, char*[]);
bool profile_conv_bwd_weight(int, char*[]);
bool profile_batched_gemm_reduce(int, char*[]);
int main(int argc, char* argv[])
{
bool pass = true;
if(strcmp(argv[1], "gemm") == 0)
{
return profile_gemm(argc, argv);
pass = profile_gemm(argc, argv);
}
if(strcmp(argv[1], "gemm_splitk") == 0)
{
pass = profile_gemm_splitk(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
return profile_gemm_bias_2d(argc, argv);
pass = profile_gemm_bias_2d(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu") == 0)
{
return profile_gemm_bias_relu(argc, argv);
pass = profile_gemm_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_relu_add") == 0)
{
return profile_gemm_bias_relu_add(argc, argv);
pass = profile_gemm_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "gemm_reduce") == 0)
{
return profile_gemm_reduce(argc, argv);
pass = profile_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm") == 0)
{
return profile_batched_gemm(argc, argv);
pass = profile_batched_gemm(argc, argv);
}
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
{
return profile_batched_gemm_reduce(argc, argv);
pass = profile_batched_gemm_reduce(argc, argv);
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
profile_grouped_gemm(argc, argv);
pass = profile_grouped_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
return ck::profiler::profile_convnd_fwd(argc, argv);
pass = profile_convnd_fwd(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu") == 0)
{
return profile_conv_fwd_bias_relu(argc, argv);
pass = profile_conv_fwd_bias_relu(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_add") == 0)
{
return profile_conv_fwd_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd_bias_relu_atomic_add") == 0)
{
return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
pass = profile_conv_fwd_bias_relu_add(argc, argv);
}
else if(strcmp(argv[1], "conv1d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 1);
pass = profile_convnd_bwd_data(argc, argv, 1);
}
else if(strcmp(argv[1], "conv2d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 2);
pass = profile_convnd_bwd_data(argc, argv, 2);
}
else if(strcmp(argv[1], "conv3d_bwd_data") == 0)
{
return profile_convnd_bwd_data(argc, argv, 3);
pass = profile_convnd_bwd_data(argc, argv, 3);
}
else if(strcmp(argv[1], "reduce") == 0)
{
return profile_reduce(argc, argv);
pass = profile_reduce(argc, argv);
}
else if(strcmp(argv[1], "conv2d_bwd_weight") == 0)
{
return profile_conv_bwd_weight(argc, argv);
pass = profile_conv_bwd_weight(argc, argv);
}
else
{
// clang-format off
printf("arg1: tensor operation (gemm: GEMM\n"
printf("arg1: tensor operation, gemm: GEMM\n"
" gemm_splitk: GEMMSplitK)\n"
" gemm_bias_2d: GEMM+Bias(2D)\n"
" gemm_bias_relu: GEMM+Bias+ReLU\n"
" gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n"
......@@ -103,13 +105,13 @@ int main(int argc, char* argv[])
" conv_fwd: ForwardConvolution\n"
" conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n"
" conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n"
" conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"
" conv1d_bwd_data: BackwardConvolution data 1 dim\n"
" conv2d_bwd_data: BackwardConvolution data 2 dim\n"
" conv3d_bwd_data: BackwardConvolution data 3 dim\n"
" reduce: REDUCE\n"
" reduce: Reduce\n"
" conv2d_bwd_weight: Backward Weight Convolution 2d\n");
// clang-format on
}
return 0;
return pass ? 0 : 1;
}
......@@ -52,7 +52,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_splitk)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
......
......@@ -31,6 +31,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
......@@ -39,6 +40,7 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
......
......@@ -33,11 +33,6 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
......@@ -63,8 +58,6 @@ int main()
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
......@@ -85,8 +78,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
......@@ -107,8 +98,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
......@@ -129,8 +118,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
......
......@@ -31,16 +31,12 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
......@@ -64,8 +60,6 @@ int main()
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
......@@ -86,8 +80,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
......@@ -108,8 +100,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
......@@ -130,8 +120,6 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
......
......@@ -31,14 +31,12 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
......@@ -57,7 +55,7 @@ int main()
bool res = true;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
......@@ -75,7 +73,7 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
......@@ -93,7 +91,7 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
......@@ -111,7 +109,7 @@ int main()
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs);
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
......
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
target_link_libraries(test_gemm_split_k PRIVATE host_tensor)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance)
add_test_executable(test_gemm_splitk_fp32 gemm_splitk_fp32.cpp)
target_link_libraries(test_gemm_splitk_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_splitk_fp32 PRIVATE device_gemm_splitk_instance)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment