Commit 99cc8431 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent 7d69eb3b
......@@ -187,101 +187,123 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38};
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
if(do_verification)
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
KBatch,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_log)
if(do_verification)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
std::string op_name = op_ptr->GetTypeString();
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(
argument_ptr.get(),
StreamConfig{
nullptr, time_kernel, 0, n_warmup, n_iter, rotating_count > 1, rotating_count});
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
#if defined CK_ENABLE_FP8
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
// set softer tolerances for fp8
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>)
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8
}
}
#endif
if(tflops > best_tflops)
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
}
if constexpr(is_same<EDataType, float>::value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment