Commit 821308ae authored by Jing Zhang's avatar Jing Zhang
Browse files

updated regular gemm

parent d52ec016
...@@ -94,7 +94,6 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -94,7 +94,6 @@ bool profile_gemm_splitk_impl(int do_verification,
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.SetZero();
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout, using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
BLayout, BLayout,
...@@ -136,77 +135,99 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -136,77 +135,99 @@ bool profile_gemm_splitk_impl(int do_verification,
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances // profile device GEMM instances
for(auto& op_ptr : op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 36, 40, 60,
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), 64, 72, 80, 88, 96, 128, 144, 160, 176, 192, 256};
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), if(KBatch > 0)
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
KBatch);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init C to zero before profiling next kernel kbatch_list = {KBatch};
c_device_buf.SetZero(); }
std::string op_name = op_ptr->GetTypeString(); for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
kbatch_curr);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = // re-init C to zero before profiling next kernel
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); c_device_buf.SetZero();
std::size_t flop = std::size_t(2) * M * N * K; invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
std::size_t num_btype = if(do_verification)
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; {
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; std::string op_name = op_ptr->GetTypeString();
float gb_per_sec = num_btype / 1.E6 / ave_time; float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::size_t flop = std::size_t(2) * M * N * K;
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops) std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
{ sizeof(CDataType) * M * N;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification) float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); float gb_per_sec = num_btype / 1.E6 / ave_time;
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(do_log) if(tflops > best_tflops)
{ {
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; best_op_name = op_name;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; best_tflops = tflops;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",") best_ave_time = ave_time;
<< std::endl; best_gb_per_sec = gb_per_sec;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") best_kbatch = kbatch_curr;
<< std::endl;
} }
} }
} else
else {
{ std::cout << op_ptr->GetTypeString() << " does not support this problem"
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; << std::endl;
}
} }
} }
...@@ -246,7 +267,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -246,7 +267,7 @@ bool profile_gemm_splitk_impl(int do_verification,
} }
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << KBatch << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl; << " GB/s, " << best_op_name << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment