"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "cb538740021d5a41d89d9c21ec2eb28b2088d046"
Commit f99f614e authored by Chao Liu's avatar Chao Liu
Browse files

update profiler

parent b238662a
...@@ -75,6 +75,7 @@ int profile_gemm_gelu_impl(int do_verification, ...@@ -75,6 +75,7 @@ int profile_gemm_gelu_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
...@@ -101,16 +102,9 @@ int profile_gemm_gelu_impl(int do_verification, ...@@ -101,16 +102,9 @@ int profile_gemm_gelu_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmGeluPtr> gemm_ptrs; std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmGeluPtr>
device_op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>) is_same_v<CDataType, half_t>)
...@@ -120,69 +114,87 @@ int profile_gemm_gelu_impl(int do_verification, ...@@ -120,69 +114,87 @@ int profile_gemm_gelu_impl(int do_verification,
is_same_v<CLayout, tensor_layout::gemm::RowMajor>) is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> && is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<CLayout, tensor_layout::gemm::RowMajor>) is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> && is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<CLayout, tensor_layout::gemm::RowMajor>) is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(device_op_ptrs);
} }
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> && else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> && is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<CLayout, tensor_layout::gemm::RowMajor>) is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(device_op_ptrs);
} }
} }
if(gemm_ptrs.size() <= 0) std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl;
// run reference
if(do_verification)
{ {
throw std::runtime_error("wrong! no device operation instance found"); using ReferenceOpInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
} }
std::string best_gemm_name; DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
std::string best_device_op_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
bool pass = true; bool pass = true;
// profile device GEMM instances // profile device operation instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& device_op_ptr : device_op_ptrs)
{ {
auto argument_ptr = auto argument_ptr = device_op_ptr->MakeArgumentPointer(
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M, M,
N, N,
K, K,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = device_op_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) std::string device_op_name = device_op_ptr->GetTypeString();
if(device_op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel // re-init C to zero before profiling a kernel
c_device_buf.SetZero(); c_device_buf.SetZero();
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
...@@ -196,37 +208,20 @@ int profile_gemm_gelu_impl(int do_verification, ...@@ -196,37 +208,20 @@ int profile_gemm_gelu_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << gemm_name << std::endl; << gb_per_sec << " GB/s, " << device_op_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_gemm_name = gemm_name; best_device_op_name = device_op_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
} }
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
pass = pass && pass = pass &&
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
...@@ -243,12 +238,12 @@ int profile_gemm_gelu_impl(int do_verification, ...@@ -243,12 +238,12 @@ int profile_gemm_gelu_impl(int do_verification,
} }
else else
{ {
std::cout << "does not support this problem" << std::endl; std::cout << device_op_name << " does not support this problem" << std::endl;
} }
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; << best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl;
return pass ? 0 : 1; return pass ? 0 : 1;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment