Commit 25751e37 authored by Jing Zhang

add selection of device_instances

parent 8160c31a
@@ -84,10 +84,11 @@ install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINAT
 set(PROFILER_SOURCE
     profiler.cpp
     profile_gemm.cpp
-    profile_conv_fwd.cpp
-    profile_conv_fwd_bias_relu.cpp
-    profile_conv_fwd_bias_relu_add.cpp
-    profile_conv_fwd_bias_relu_atomic_add.cpp)
+    #profile_conv_fwd.cpp
+    #profile_conv_fwd_bias_relu.cpp
+    #profile_conv_fwd_bias_relu_add.cpp
+    #profile_conv_fwd_bias_relu_atomic_add.cpp
+)
 add_executable(ckProfiler ${PROFILER_SOURCE})
 target_link_libraries(ckProfiler PRIVATE host_tensor)
@@ -7,6 +7,11 @@ namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {
+using DeviceGemmNoOpPtr =
+    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                ck::tensor_operation::element_wise::PassThrough>;
+#if 0
 template <>
 void add_device_gemm_instance<float,
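
Aside: the added `DeviceGemmNoOpPtr` pins all three element-wise operators to `PassThrough`, i.e. a plain GEMM with no fused bias or activation. Below is a minimal, self-contained sketch of the registry pattern this alias supports; the types and names here are placeholders for illustration, not the real CK API.

    // Sketch of the instance-registry pattern (placeholder types, not CK's API):
    // factories append polymorphic device ops to a vector of owning pointers,
    // which the profiler later iterates.
    #include <memory>
    #include <string>
    #include <vector>

    struct PassThrough {}; // stand-in for element_wise::PassThrough

    template <typename AElementwiseOp, typename BElementwiseOp, typename CElementwiseOp>
    struct DeviceGemmBase // stand-in for the DeviceGemm interface
    {
        virtual ~DeviceGemmBase() = default;
        virtual std::string GetTypeString() const = 0;
    };

    // Analogue of DeviceGemmNoOpPtr: every element-wise op is PassThrough.
    using DeviceGemmNoOpPtr =
        std::unique_ptr<DeviceGemmBase<PassThrough, PassThrough, PassThrough>>;

    struct DummyGemmInstance : DeviceGemmBase<PassThrough, PassThrough, PassThrough>
    {
        std::string GetTypeString() const override { return "dummy_gemm_instance"; }
    };

    // Analogue of the add_device_gemm_xdl_*_instances factories in this commit.
    void add_dummy_instances(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
    {
        gemm_ptrs.push_back(std::make_unique<DummyGemmInstance>());
    }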
@@ -175,15 +180,77 @@ void profile_gemm_impl(int do_verification,
     if(KBatch > 1 && is_same<ADataType, float>::value)
     {
-        ck::tensor_operation::device::device_gemm_instance::
-            add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
-                gemm_ptrs);
+        // ck::tensor_operation::device::device_gemm_instance::
+        //     add_device_splitk_gemm_instance<float, float, float, ALayout, BLayout, CLayout>(
+        //         gemm_ptrs);
     }
     else
     {
-        ck::tensor_operation::device::device_gemm_instance::
-            add_device_gemm_instance<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
-                gemm_ptrs);
+        if(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
+           is_same<CDataType, float>::value)
+        {
+            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
+            }
+        }
+        else if(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
+                is_same<CDataType, half_t>::value)
+        {
+            if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+               is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
+            }
+            else if(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
+                    is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
+            {
+                ck::tensor_operation::device::device_gemm_instance::
+                    add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
+            }
+        }
     }

     if(gemm_ptrs.size() <= 0)
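
The hunk above replaces the single generic `add_device_gemm_instance` call with compile-time selection: only the instance library matching the profiled element type (f32/f16) and the A/B/C layouts is registered. Below is a self-contained sketch of the same `is_same`-based dispatch pattern, using stub tag types and factories rather than the real CK symbols.

    // Sketch of the layout dispatch used above, with stub types and factories.
    #include <cstdio>
    #include <type_traits>
    #include <vector>

    struct RowMajor {};    // stand-in for tensor_layout::gemm::RowMajor
    struct ColumnMajor {}; // stand-in for tensor_layout::gemm::ColumnMajor

    // Stub factories; in CK each appends tuned device ops to gemm_ptrs.
    static void add_f32_mk_kn_mn_instances(std::vector<int>& v) { v.push_back(0); }
    static void add_f32_mk_nk_mn_instances(std::vector<int>& v) { v.push_back(1); }

    template <typename ADataType, typename ALayout, typename BLayout, typename CLayout>
    void add_matching_instances(std::vector<int>& gemm_ptrs)
    {
        if(std::is_same<ADataType, float>::value)
        {
            if(std::is_same<ALayout, RowMajor>::value &&
               std::is_same<BLayout, RowMajor>::value &&
               std::is_same<CLayout, RowMajor>::value)
            {
                add_f32_mk_kn_mn_instances(gemm_ptrs); // A: MxK, B: KxN, C: MxN
            }
            else if(std::is_same<ALayout, RowMajor>::value &&
                    std::is_same<BLayout, ColumnMajor>::value &&
                    std::is_same<CLayout, RowMajor>::value)
            {
                add_f32_mk_nk_mn_instances(gemm_ptrs); // A: MxK, B: NxK, C: MxN
            }
        }
    }

    int main()
    {
        std::vector<int> gemm_ptrs;
        add_matching_instances<float, RowMajor, ColumnMajor, RowMajor>(gemm_ptrs);
        std::printf("registered %zu instance(s)\n", gemm_ptrs.size());
        return 0;
    }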
@@ -6,10 +6,10 @@
 #include <half.hpp>
 int profile_gemm(int, char*[]);
-int profile_conv_fwd(int, char*[]);
-int profile_conv_fwd_bias_relu(int, char*[]);
-int profile_conv_fwd_bias_relu_add(int, char*[]);
-int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
+// int profile_conv_fwd(int, char*[]);
+// int profile_conv_fwd_bias_relu(int, char*[]);
+// int profile_conv_fwd_bias_relu_add(int, char*[]);
+// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
 int main(int argc, char* argv[])
 {
@@ -17,6 +17,7 @@ int main(int argc, char* argv[])
     {
         return profile_gemm(argc, argv);
     }
+#if 0
     else if(strcmp(argv[1], "conv_fwd") == 0)
     {
         return profile_conv_fwd(argc, argv);
@@ -33,6 +34,7 @@ int main(int argc, char* argv[])
     {
         return profile_conv_fwd_bias_relu_atomic_add(argc, argv);
     }
+#endif
     else
     {
         printf("arg1: tensor operation (gemm: GEMM;\n"
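
Note on the mechanism: unlike the `//` comments used for the declarations above, the added `#if 0 … #endif` pair makes the preprocessor drop the conv branches wholesale, so the now-commented `profile_conv_fwd*` declarations are never referenced. A self-contained sketch of the resulting shape of main, with a stub profiler function standing in for the real one:

    // Sketch: disabling dispatch branches with #if 0, as in this commit.
    #include <cstdio>
    #include <cstring>

    static int profile_gemm(int, char*[]) { return 0; } // stub

    int main(int argc, char* argv[])
    {
        if(argc >= 2 && std::strcmp(argv[1], "gemm") == 0)
        {
            return profile_gemm(argc, argv);
        }
    #if 0 // dropped by the preprocessor; profile_conv_fwd need not even exist
        else if(std::strcmp(argv[1], "conv_fwd") == 0)
        {
            return profile_conv_fwd(argc, argv);
        }
    #endif
        else
        {
            std::printf("arg1: tensor operation (gemm: GEMM)\n");
            return 1;
        }
    }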