Commit 37d83d7d authored by Anthony Chang

refactor standalone test to use gemm test harness

parent 06b650d2
@@ -38,8 +38,10 @@ using CElementOp = PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using ck::gemm_util::GemmParams;
using ck::tensor_operation::device::BaseOperator;
using namespace ck::tensor_operation::device;
using ck::tensor_operation::device::DeviceGemm;
using namespace ck::tensor_operation::device::instance;
using DeviceGemmNN =
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
@@ -50,24 +52,6 @@ using DeviceGemmTN =
using DeviceGemmTT =
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;
struct ProblemSize
{
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
};
struct ExecutionConfig
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
struct LayoutConfig
{
bool ARowMajor;
@@ -75,28 +59,6 @@ struct LayoutConfig
bool CRowMajor;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
bool run_gemm(const ProblemSize& problem_size,
const ExecutionConfig& config,
ck::tensor_operation::device::DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>* gemm_instance_ptr);
int main(int argc, char* argv[])
{
// Class DeviceGemm is templated by layout and precision types so it is not an option to contain
@@ -106,212 +68,100 @@ int main(int argc, char* argv[])
// LayoutConfig is used for determining which type a BaseOperator instance should be cast to.
using OpFactoryFn = void (*)(std::vector<std::unique_ptr<BaseOperator>>&);
const std::vector<std::tuple<ProblemSize, LayoutConfig, OpFactoryFn>> problems = {
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
// 104 tiles
{ProblemSize{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_256x256},
{ProblemSize{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_256x128},
{ProblemSize{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_128x128},
{ProblemSize{1024, 832, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_128x64},
{ProblemSize{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_256x256},
{ProblemSize{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_256x128},
{ProblemSize{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_128x128},
{ProblemSize{1024, 832, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_128x64},
{ProblemSize{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_256x128},
{ProblemSize{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_256x128},
{ProblemSize{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_128x128},
{ProblemSize{1024, 832, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_128x64},
{ProblemSize{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_256x256},
{ProblemSize{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_256x128},
{ProblemSize{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_128x128},
{ProblemSize{1024, 832, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_128x64},
{GemmParams{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1024, 832, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1024, 832, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1024, 832, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2048, 3328, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2048, 1664, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1024, 1664, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1024, 832, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// 110 tiles
{ProblemSize{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_256x256},
{ProblemSize{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_256x128},
{ProblemSize{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_128x128},
{ProblemSize{1280, 704, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, instance::add_gemm_f16_nn_128x64},
{ProblemSize{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_256x256},
{ProblemSize{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_256x128},
{ProblemSize{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_128x128},
{ProblemSize{1280, 704, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, instance::add_gemm_f16_nt_128x64},
{ProblemSize{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_256x128},
{ProblemSize{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_256x128},
{ProblemSize{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_128x128},
{ProblemSize{1280, 704, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, instance::add_gemm_f16_tn_128x64},
{ProblemSize{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_256x256},
{ProblemSize{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_256x128},
{ProblemSize{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_128x128},
{ProblemSize{1280, 704, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, instance::add_gemm_f16_tt_128x64},
{GemmParams{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1280, 704, 4096, -1, -1, -1}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1280, 704, 4096, -1, -1, -1}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1280, 704, 4096, -1, -1, -1}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2560, 2816, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2560, 1408, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1280, 1408, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1280, 704, 4096, -1, -1, -1}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
};
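// Where the "104 tiles" / "110 tiles" comments come from, assuming the tile
// count is (M / TileM) * (N / TileN) for each kernel's macro-tile (my
// arithmetic, not part of the commit):
//   256x256: (2048/256) * (3328/256) = 8 * 13 = 104
//   256x128: (2048/256) * (1664/128) = 8 * 13 = 104
//   128x128: (1024/128) * (1664/128) = 8 * 13 = 104
//   128x64 : (1024/128) * ( 832/ 64) = 8 * 13 = 104
// and likewise (2560/256) * (2816/256) = 10 * 11 = 110 for the second group.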
ExecutionConfig config{true, 1, true};
bool do_verification = true;
bool time_kernel = true;
if(argc == 4)
if(argc == 1)
{
// use default
}
else if(argc == 3)
{
do_verification = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
}
else
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: time kernel (0=no, 1=yes)" << std::endl;
}
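// Hypothetical invocations of the refactored test (the binary name here is
// illustrative, not from the commit):
//   ./test_gemm_standalone         -> defaults: verification on, timing on
//   ./test_gemm_standalone 1 0     -> argv[1] = verification, argv[2] = time kernel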
for(auto& p : problems)
{
const ProblemSize& problem_size = std::get<0>(p);
GemmParams& problem_size = std::get<0>(p);
const LayoutConfig& layout_config = std::get<1>(p);
const auto& factory = std::get<2>(p);
std::vector<std::unique_ptr<BaseOperator>> ops;
factory(ops);
problem_size.StrideA = layout_config.ARowMajor ? problem_size.K : problem_size.M;
problem_size.StrideB = layout_config.BRowMajor ? problem_size.N : problem_size.K;
problem_size.StrideC = layout_config.CRowMajor ? problem_size.N : problem_size.M;
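// A concrete instance of the stride selection above, assuming BLAS-style
// leading dimensions: a row-major A (M x K) strides by the column count
// between rows, e.g. M=2048, K=4096 -> StrideA = 4096; a column-major A
// strides by the row count -> StrideA = 2048.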
if(!layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNN*>(ops[0].get());
run_gemm(problem_size, config, op_ptr);
ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(!layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmNT*>(ops[0].get());
run_gemm(problem_size, config, op_ptr);
ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && !layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTN*>(ops[0].get());
run_gemm(problem_size, config, op_ptr);
ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
else if(layout_config.ARowMajor && layout_config.BRowMajor)
{
auto op_ptr = dynamic_cast<DeviceGemmTT*>(ops[0].get());
run_gemm(problem_size, config, op_ptr);
ck::gemm_util::TestGemm<AccDataType>{}(
op_ptr, problem_size, do_verification, time_kernel);
}
}
return 0;
}
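The dispatch in main reduces to the following self-contained sketch; Base, GemmNN, make_nn, and FactoryFn are illustrative stand-ins for BaseOperator, DeviceGemmNN, the add_gemm_f16_* factories, and OpFactoryFn, not CK API:
#include <iostream>
#include <memory>
#include <vector>
struct Base { virtual ~Base() = default; };                 // plays BaseOperator
struct GemmNN : Base { void Run() const { std::cout << "NN kernel\n"; } };
using FactoryFn = void (*)(std::vector<std::unique_ptr<Base>>&);
void make_nn(std::vector<std::unique_ptr<Base>>& ops)       // plays add_gemm_f16_nn_*
{
    ops.push_back(std::make_unique<GemmNN>());
}
int main()
{
    std::vector<std::unique_ptr<Base>> ops;
    FactoryFn factory = make_nn; // the problem table stores one of these per entry
    factory(ops);                // factory fills the type-erased op list
    if(auto* p = dynamic_cast<GemmNN*>(ops[0].get())) // cast chosen via LayoutConfig
        p->Run();
}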
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
bool run_gemm(const ProblemSize& problem_size,
const ExecutionConfig& config,
ck::tensor_operation::device::DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>* gemm_instance_ptr)
{
// using namespace ck::literals;
auto [M, N, K, StrideA, StrideB, StrideC] = problem_size;
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t& stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
stride = stride == -1 ? col : stride;
return HostTensorDescriptor({row, col}, {stride, 1});
}
else
{
stride = stride == -1 ? row : stride;
return HostTensorDescriptor({row, col}, {1, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k.begin(),
a_m_k.end());
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n.begin(),
b_k_n.end());
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k.begin(), a_m_k.end());
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n.begin(), b_k_n.end());
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto& gemm = *gemm_instance_ptr;
auto invoker = gemm.MakeInvokerPointer();
auto argument =
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument.get()))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}
return true;
}
@@ -16,21 +16,13 @@ namespace gemm_util {
struct GemmParams
{
GemmParams()
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
{
}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
float alpha;
float beta;
ck::index_t StrideA = 1024;
ck::index_t StrideB = 1024;
ck::index_t StrideC = 1024;
};
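A side effect worth noting, as I read the diff: dropping the user-declared constructor in favor of default member initializers makes GemmParams an aggregate again, which is what allows the positional brace-init used throughout the standalone test's problem list:
GemmParams p{2048, 3328, 4096, -1, -1, -1}; // M, N, K, StrideA, StrideB, StrideC
GemmParams q{};                             // every member keeps its 1024 default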
template <typename GemmInstance,
@@ -69,7 +61,8 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
bool time_kernel)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
@@ -94,7 +87,20 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
{
a_m_k_device_buf.ToDevice(A.mData.data());
b_k_n_device_buf.ToDevice(B.mData.data());
invoker_ptr->Run(argument_ptr.get());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * params.M * params.N * params.K;
std::size_t num_btype = sizeof(ADataType) * params.M * params.K +
sizeof(BDataType) * params.K * params.N +
sizeof(CDataType) * params.M * params.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << std::endl;
c_m_n_device_buf.FromDevice(C.mData.data());
return true;
@@ -109,19 +115,15 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
}
}
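The restructuring of TestGemm below leans on one C++ mechanism, sketched here with illustrative names: a template template parameter lets operator() accept a pointer to a concrete DeviceGemm specialization and deduce every layout, data type, and elementwise op from it, so the enclosing struct only needs AccDataType:
template <template <class...> class Op, class... Ts>
void run(Op<Ts...>* p); // all Ts... deduced from the pointer's concrete type
// e.g. passing a DeviceGemmNN* deduces Op = DeviceGemm and
// Ts... = Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough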
template <typename DeviceGemmPtr_,
typename ADataType,
template <typename AccDataType>
struct TestGemm
{
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct TestGemm
{
typename CLayout>
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
@@ -156,25 +158,42 @@ struct TestGemm
f_generate_tensor_value(a_m_k, ADataType{});
f_generate_tensor_value(b_k_n, BDataType{});
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceGemmPtr_& gemmPtr)
template <template <class...> class DeviceGemmPtr_,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
auto operator()(DeviceGemmPtr_<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>* gemmPtr,
const GemmParams& params = GemmParams{},
bool do_verification = true,
bool time_kernel = false)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl;
// Arrange
ck::gemm_util::GemmParams params;
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params);
auto host_tensors =
PrepareGemmTensor<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
@@ -193,14 +212,18 @@ struct TestGemm
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
if(do_verification)
{
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
}
// Act
bool is_supported = ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op, time_kernel);
if(is_supported)
if(is_supported && do_verification)
{
// Assert
bool res = false;
......