"...composable_kernel_rocm.git" did not exist on "dbe0691104c4e8510ee3104150fe4b3ec5e1eeb4"
Commit 99cc8431 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 7d69eb3b
...@@ -187,101 +187,123 @@ bool profile_gemm_multiply_multiply_impl(int do_verification, ...@@ -187,101 +187,123 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38};
c_device_buf.SetZero();
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter}); for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
if(do_verification) auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); // re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_log) if(do_verification)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; c_device_buf.FromDevice(e_m_n_device_result.mData.data());
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",") pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",") if(do_log)
<< std::endl; {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
} }
}
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run( float ave_time = invoker_ptr->Run(argument_ptr.get(),
argument_ptr.get(), StreamConfig{nullptr,
StreamConfig{ time_kernel,
nullptr, time_kernel, 0, n_warmup, n_iter, rotating_count > 1, rotating_count}); 0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< gb_per_sec << " GB/s, " << op_name << std::endl; << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
// set softer tolerances for fp8 // set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> || if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) is_same_v<EDataType, f8_t>)
{ {
std::string msg = "Error: Incorrect results!"; std::string msg = "Error: Incorrect results!";
double rtol = 1e-1; double rtol = 1e-1;
double atol = 1e-1; double atol = 1e-1;
pass = pass & ck::utils::check_err( pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
} }
else else
{ {
#endif #endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
} }
#endif #endif
if(tflops > best_tflops) if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{ {
best_op_name = op_name; std::cout << op_ptr->GetTypeString() << " does not support this problem"
best_tflops = tflops; << std::endl;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
} }
} }
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
} }
if constexpr(is_same<EDataType, float>::value) if constexpr(is_same<EDataType, float>::value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment