Commit 25751e37 authored by Jing Zhang

add selection of device_instances

parent 8160c31a
@@ -84,10 +84,11 @@ install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINAT
 set(PROFILER_SOURCE
     profiler.cpp
     profile_gemm.cpp
-    profile_conv_fwd.cpp
-    profile_conv_fwd_bias_relu.cpp
-    profile_conv_fwd_bias_relu_add.cpp
-    profile_conv_fwd_bias_relu_atomic_add.cpp)
+    #profile_conv_fwd.cpp
+    #profile_conv_fwd_bias_relu.cpp
+    #profile_conv_fwd_bias_relu_add.cpp
+    #profile_conv_fwd_bias_relu_atomic_add.cpp
+)
 
 add_executable(ckProfiler ${PROFILER_SOURCE})
 target_link_libraries(ckProfiler PRIVATE host_tensor)
...
@@ -7,6 +7,11 @@ namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
 
+using DeviceGemmNoOpPtr =
+    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                ck::tensor_operation::element_wise::PassThrough>;
+
 #if 0
 template <>
 void add_device_gemm_instance<float,
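Note: DeviceGemmNoOpPtr names a device-GEMM handle whose A/B/C element-wise operations are all PassThrough (i.e. no-ops). A minimal sketch of how the selection code uses it; the container name gemm_ptrs and the add_*_instances functions come from profile_gemm.cpp below, while the surrounding scaffolding is an illustrative assumption:

    #include <vector>

    using ck::tensor_operation::device::device_gemm_instance::DeviceGemmNoOpPtr;

    // Collect every registered no-op GEMM instance into one container;
    // the profiler then benchmarks each entry in turn.
    std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);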
@@ -175,15 +180,77 @@ void profile_gemm_impl(int do_verification,
     if(KBatch > 1 && is_same<ADataType, float>::value)
     {
-        ck::tensor_operation::device::device_gemm_instance::
-            add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
-                gemm_ptrs);
+        // ck::tensor_operation::device::device_gemm_instance::
+        //     add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
+        //         gemm_ptrs);
     }
     else
     {
-        ck::tensor_operation::device::device_gemm_instance::
-            add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
-                gemm_ptrs);
+        if(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
+           is_same<CDataType, float>::value)
+        {
+            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            }
+        }
+        else if(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
+                is_same<CDataType, half_t>::value)
+        {
+            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+            }
+        }
     }
 
     if(gemm_ptrs.size() <= 0)
...
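The rewritten else-branch replaces the single templated add_device_gemm_instance call with explicit per-combination selection: is_same tests on the data-type and layout template parameters pick exactly one add_device_gemm_xdl_* registration function for each (type, layout) pairing. Because ADataType, ALayout, etc. are compile-time template parameters, each test folds to a constant and the untaken branches are trivially dead. A reduced, self-contained sketch of the same dispatch shape, using stand-in tag types rather than the real CK headers:

    #include <type_traits>
    #include <vector>

    struct Row {}; // stands in for tensor_layout::gemm::RowMajor
    struct Col {}; // stands in for tensor_layout::gemm::ColumnMajor

    using InstancePtr = int; // placeholder for DeviceGemmNoOpPtr

    void add_f32_mk_kn_mn(std::vector<InstancePtr>& v) { v.push_back(0); }
    void add_f32_mk_nk_mn(std::vector<InstancePtr>& v) { v.push_back(1); }

    template <typename ADataType, typename ALayout, typename BLayout>
    void select_instances(std::vector<InstancePtr>& ptrs)
    {
        // Same chain-of-is_same shape as the diff: only the branch matching
        // this instantiation's template arguments registers instances.
        if(std::is_same<ADataType, float>::value)
        {
            if(std::is_same<ALayout, Row>::value && std::is_same<BLayout, Row>::value)
                add_f32_mk_kn_mn(ptrs);
            else if(std::is_same<ALayout, Row>::value && std::is_same<BLayout, Col>::value)
                add_f32_mk_nk_mn(ptrs);
        }
    }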
@@ -6,10 +6,10 @@
 #include <half.hpp>
 
 int profile_gemm(int, char*[]);
-int profile_conv_fwd(int, char*[]);
-int profile_conv_fwd_bias_relu(int, char*[]);
-int profile_conv_fwd_bias_relu_add(int, char*[]);
-int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
+// int profile_conv_fwd(int, char*[]);
+// int profile_conv_fwd_bias_relu(int, char*[]);
+// int profile_conv_fwd_bias_relu_add(int, char*[]);
+// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
 
 int main(int argc, char* argv[])
 {
@@ -17,6 +17,7 @@ int main(int argc, char* argv[])
     {
         return profile_gemm(argc, argv);
     }
+#if 0
     else if(strcmp(argv[1], "conv_fwd") == 0)
     {
         return profile_conv_fwd(argc, argv);
@@ -33,6 +34,7 @@ int main(int argc, char* argv[])
     {
         return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
     }
+#endif
     else
     {
         printf("arg1: tensor operation (gemm: GEMM;\n"
...
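Net effect at the command line: with the conv declarations commented out and the conv dispatch wrapped in #if 0, only the gemm path remains reachable, and any conv_* arg1 now falls through to the usage message. Assuming the executable name from the CMake hunk above, the surviving invocation is of the form:

    ./ckProfiler gemm <existing gemm arguments>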