Commit df6f43d8 authored by Jing Zhang

clean code

parent 25751e37
@@ -84,16 +84,16 @@ install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINAT
 set(PROFILER_SOURCE
     profiler.cpp
     profile_gemm.cpp
-    #profile_conv_fwd.cpp
-    #profile_conv_fwd_bias_relu.cpp
-    #profile_conv_fwd_bias_relu_add.cpp
-    #profile_conv_fwd_bias_relu_atomic_add.cpp
+    profile_conv_fwd.cpp
+    profile_conv_fwd_bias_relu.cpp
+    profile_conv_fwd_bias_relu_add.cpp
+    profile_conv_fwd_bias_relu_atomic_add.cpp
 )

 add_executable(ckProfiler ${PROFILER_SOURCE})

 target_link_libraries(ckProfiler PRIVATE host_tensor)
 target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
-#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
-#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
-#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
@@ -178,72 +178,96 @@ void profile_gemm_impl(int do_verification,
     // add device GEMM instances
     std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr> gemm_ptrs;

-    if(KBatch > 1 && is_same<ADataType, float>::value)
-    {
-        // ck::tensor_operation::device::device_gemm_instance::
-        //     add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
-        //         gemm_ptrs);
-    }
-    else
-    {
-        if(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
-           is_same<CDataType, float>::value)
-        {
-            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
-            }
-        }
-        else if(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
-                is_same<CDataType, half_t>::value)
-        {
-            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
-                ck::tensor_operation::device::device_gemm_instance::
-                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
-            }
-            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-            {
+    if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
+                 is_same<CDataType, float>::value)
+    {
+        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            if(KBatch > 1)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            }
+            else
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            }
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            if(KBatch > 1)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            }
+            else
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            }
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            if(KBatch > 1)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            }
+            else
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            }
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            if(KBatch > 1)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            }
+            else
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            }
+        }
+    }
+    else if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
+                      is_same<CDataType, half_t>::value)
+    {
+        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            ck::tensor_operation::device::device_gemm_instance::
+                add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            ck::tensor_operation::device::device_gemm_instance::
+                add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
+            ck::tensor_operation::device::device_gemm_instance::
+                add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+        }
+        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+        {
@@ -251,7 +275,6 @@ void profile_gemm_impl(int do_verification,
                 add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
         }
     }
-    }

     if(gemm_ptrs.size() <= 0)
     {
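The substantive change in profile_gemm_impl is replacing the runtime dtype/layout checks, plus the separate top-level KBatch > 1 branch, with if constexpr dispatch that folds split-K instance selection into each f32 layout branch. With a plain if, every branch has to compile for every template instantiation of the function; if constexpr discards the not-taken branches at compile time, so each instantiation only references the instance-factory functions that actually exist for its (data type, layout) combination. A minimal standalone sketch of the pattern follows; the type names and add_* functions are illustrative stand-ins, not CK's actual API:

    #include <iostream>
    #include <type_traits>

    // Illustrative stand-ins for ck::tensor_layout::gemm::* and ck::half_t.
    struct RowMajor {};
    struct ColumnMajor {};
    struct Half {};

    // Stand-ins for the add_device_gemm_xdl_*_instances() factory functions.
    void add_f32_rr_instances()        { std::cout << "f32 row-row\n"; }
    void add_f32_rr_splitk_instances() { std::cout << "f32 row-row split-K\n"; }
    void add_f16_rr_instances()        { std::cout << "f16 row-row\n"; }

    template <typename DataType, typename ALayout, typename BLayout>
    void add_instances(int KBatch)
    {
        if constexpr(std::is_same<DataType, float>::value)
        {
            if constexpr(std::is_same<ALayout, RowMajor>::value &&
                         std::is_same<BLayout, RowMajor>::value)
            {
                // KBatch is a runtime value, so split-K selection stays an
                // ordinary branch nested inside the compile-time ones.
                if(KBatch > 1)
                    add_f32_rr_splitk_instances();
                else
                    add_f32_rr_instances();
            }
        }
        else if constexpr(std::is_same<DataType, Half>::value)
        {
            if constexpr(std::is_same<ALayout, RowMajor>::value &&
                         std::is_same<BLayout, RowMajor>::value)
            {
                add_f16_rr_instances(); // no split-K instances for f16 in this commit
            }
        }
    }

    int main()
    {
        add_instances<float, RowMajor, RowMajor>(4); // selects the split-K path
        add_instances<Half, RowMajor, RowMajor>(1);  // f16 path, KBatch unused
    }

This structure is also why the old wrapping else { ... } disappears (the lone removed closing brace in the second hunk): each data type now gets its own else if constexpr arm, and split-K no longer needs a dedicated top-level branch.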
@@ -6,10 +6,10 @@
 #include <half.hpp>

 int profile_gemm(int, char*[]);
-// int profile_conv_fwd(int, char*[]);
-// int profile_conv_fwd_bias_relu(int, char*[]);
-// int profile_conv_fwd_bias_relu_add(int, char*[]);
-// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
+int profile_conv_fwd(int, char*[]);
+int profile_conv_fwd_bias_relu(int, char*[]);
+int profile_conv_fwd_bias_relu_add(int, char*[]);
+int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);

 int main(int argc, char* argv[])
 {
@@ -17,7 +17,6 @@ int main(int argc, char* argv[])
     {
         return profile_gemm(argc, argv);
     }
-#if 0
     else if(strcmp(argv[1], "conv_fwd") == 0)
     {
         return profile_conv_fwd(argc, argv);
@@ -34,7 +33,6 @@ int main(int argc, char* argv[])
     {
         return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
     }
-#endif
     else
     {
         printf("arg1: tensor operation (gemm: GEMM;\n"
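With the #if 0 guard removed and the conv sources compiled back into the target, the conv subcommands of main are reachable again. arg1 selects the tensor operation, with names taken from the strcmp checks above; the operation-specific arguments that follow are elided here:

    ./ckProfiler gemm ...
    ./ckProfiler conv_fwd ...
    ./ckProfiler conv_fwd_bias_relu_atomic_add ...

The dispatch lines for conv_fwd_bias_relu and conv_fwd_bias_relu_add sit in the collapsed region of the diff but, judging by the re-enabled declarations, follow the same pattern.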