Commit 60b12489 authored by wangshaojie6's avatar wangshaojie6
Browse files

add dynamic kernel example

parent 16467e0e
......@@ -207,6 +207,7 @@ int main(int argc, char* argv[])
std::cout << "b device buf: " << b_k_n_device_buf.GetDeviceBuffer() << std::endl;
std::cout << "c device buf: " << c_m_n_device_buf.GetDeviceBuffer() << std::endl;
#if USEING_STATIC_KERNEL
// do GEMM
if(M == 16 && N == 1152 && K == 5120 && splitk == 8)
{
......@@ -421,5 +422,63 @@ int main(int argc, char* argv[])
}
}
#else
// dynamic kernel
{
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
splitk);
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
}
}
#endif
return 0;
}
......@@ -48,7 +48,7 @@
#include "amd_xdlops.hpp"
#endif
#define USEING_STATIC_KERNEL 1
#define USEING_STATIC_KERNEL 0
#define MNKB_0_8 0
#define MNKB_1_4 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment