Unverified Commit 061ac064 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Polished Grouped GEMM APIs and new BF16 instances (#1600)

* Few small fixes.

* New GroupedGemm instances (BF16)

* Unify and refactor GroupedGEMM device API.

* Adapt changes to new API.

* Adapt grouped gemm profiler.

* Accept multiple kbatches for grouped gemm profiler.

- delete obsolete two stage as it is now covered by grouped gemm

* Update unit test for grouped gemm.

* Fix thresholds for BF16 and F8. Unblock tests.

* Fix few instances.

* Multiple small fixes.

* Adapt to new API, check dynamic casting.

* Uncomment few data types in grouped gemm profiler.

* Fix call to SetDeviceArgs.

* Fix profile grouped gemm multiply tile loop.

* Fix grouped gemm tile loop kernel args in client examples.

* Review comments.
parent cb8c7f42
...@@ -127,7 +127,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, ...@@ -127,7 +127,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
p_b.reserve(group_count); p_b.reserve(group_count);
p_c.reserve(group_count); p_c.reserve(group_count);
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<>; using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs; std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<KernelArguments> gemm_kargs; std::vector<KernelArguments> gemm_kargs;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_grouped_gemm_two_stage_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
{
continue;
}
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
{
kbatch_list = {kbatch};
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->GetDeviceKernelArgSize(argument_ptr.get()));
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
gemm_desc_workspace.SetZero();
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck
...@@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -39,16 +39,13 @@ namespace { ...@@ -39,16 +39,13 @@ namespace {
std::vector<int> argToIntArray(char* input) std::vector<int> argToIntArray(char* input)
{ {
std::vector<int> out; std::vector<int> out;
std::istringstream in(input); std::istringstream in(input);
std::string item; std::string item;
while(std::getline(in, item, ',')) while(std::getline(in, item, ','))
{ {
out.push_back(std::stoi(item)); out.push_back(std::stoi(item));
} }
return out; return out;
} }
...@@ -69,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -69,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[])
<< "arg7: time kernel (0=n0, 1=yes)\n" << "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n" "64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 1)\n" << "arg15: kbatch values (default 1)\n"
<< "optional:\n" << "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n" << "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n" << "arg17: number of iterations (default 10)\n"
...@@ -92,7 +89,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -92,7 +89,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]); const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]); const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; const auto kbatches = argc >= 15 ? argToIntArray(argv[14]) : std::vector<int>{};
int n_warmup = 1; int n_warmup = 1;
int n_iter = 10; int n_iter = 10;
...@@ -102,7 +99,6 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -102,7 +99,6 @@ int profile_grouped_gemm(int argc, char* argv[])
n_iter = std::stoi(argv[16]); n_iter = std::stoi(argv[16]);
} }
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
ck::profiler::profile_grouped_gemm_impl<ck::half_t, ck::profiler::profile_grouped_gemm_impl<ck::half_t,
...@@ -121,7 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -121,7 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -143,7 +139,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -143,7 +139,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -165,7 +161,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -165,7 +161,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -187,7 +183,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -187,7 +183,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -209,7 +205,7 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -209,7 +205,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -231,7 +227,73 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -231,7 +227,73 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup, n_warmup,
n_iter); n_iter);
} }
...@@ -239,7 +301,6 @@ int profile_grouped_gemm(int argc, char* argv[]) ...@@ -239,7 +301,6 @@ int profile_grouped_gemm(int argc, char* argv[])
{ {
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
} }
#endif
return 0; return 0;
} }
......
...@@ -32,9 +32,7 @@ namespace { ...@@ -32,9 +32,7 @@ namespace {
std::vector<int> argToIntArray(char* input) std::vector<int> argToIntArray(char* input)
{ {
std::vector<int> out; std::vector<int> out;
std::istringstream in(input); std::istringstream in(input);
std::string item; std::string item;
while(std::getline(in, item, ',')) while(std::getline(in, item, ','))
...@@ -83,7 +81,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -83,7 +81,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]); const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]); const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]); const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
...@@ -97,8 +95,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) ...@@ -97,8 +95,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
int n_iter = 10; int n_iter = 10;
if(argc == 17) if(argc == 17)
{ {
n_warmup = std::stoi(argv[16]); n_warmup = std::stoi(argv[15]);
n_iter = std::stoi(argv[17]); n_iter = std::stoi(argv[16]);
} }
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
#include "profiler_operation_registry.hpp"
// Matrix memory layouts selectable from the command line (value = CLI integer).
// Naming: <A layout>_<B layout>_<C layout>; MK = row-major, KM/NK = column-major.
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0: A row-major, B row-major, C row-major
MK_NK_MN, // 1: A row-major, B column-major, C row-major
};
// A/B/C data-type combinations selectable from the command line (value = CLI integer).
enum struct GemmDataType
{
F16_F16_F16, // 0: fp16 A, fp16 B, fp16 C
BF16_INT8_BF16, // 1: bf16 A, int8 B, bf16 C
BF16_BF16_BF16 // 2: bf16 A, bf16 B, bf16 C
};
#define OP_NAME "grouped_gemm_two_stage"
#define OP_DESC "Grouped GEMM TwoStage"
namespace {
// Parse a comma-separated integer list (e.g. "256,128,64") into a vector.
std::vector<int> argToIntArray(char* input)
{
    std::vector<int> values;

    std::istringstream stream(input);
    for(std::string token; std::getline(stream, token, ',');)
    {
        values.push_back(std::stoi(token));
    }

    return values;
}
// Entry point for the "grouped_gemm_two_stage" profiler operation.
//
// Command-line contract (argN == argv[N]):
//   arg2: data type, arg3: layout, arg4: verify, arg5: init, arg6: log,
//   arg7: time kernel, arg8-13: Ms/Ns/Ks/StrideAs/StrideBs/StrideCs,
//   arg14: kbatch (optional), arg15/arg16: warm-up / iteration counts (optional).
//
// Dispatches to profile_grouped_gemm_two_stage_impl with the matching
// data-type / layout template arguments. Returns 0 on completion; throws on an
// unsupported data-type/layout combination.
int profile_grouped_gemm_two_stage(int argc, char* argv[])
{
    if(argc < 14)
    {
        std::cout
            << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
            << "arg2: data type (0: fp16; 1: bf16@int8; 2: bf16)\n"
            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
            << "arg4: verification (0: no; 1: yes)\n"
            << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
            << "arg6: print tensor value (0: no; 1: yes)\n"
            << "arg7: time kernel (0=no, 1=yes)\n"
            << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
               "64,64 64,64 128,128)\n"
            << "arg14: kbatch value (default 1)\n"
            << "optional:\n"
            << "arg15: number of warm-up cycles (default 1)\n"
            << "arg16: number of iterations (default 10)\n"
            << std::endl;

        exit(1);
    }

    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
    const bool do_verification = std::stoi(argv[4]);
    const int init_method      = std::stoi(argv[5]);
    const bool do_log          = std::stoi(argv[6]);
    const bool time_kernel     = std::stoi(argv[7]);

    const auto Ms = argToIntArray(argv[8]);
    const auto Ns = argToIntArray(argv[9]);
    const auto Ks = argToIntArray(argv[10]);

    auto StrideAs = argToIntArray(argv[11]);
    auto StrideBs = argToIntArray(argv[12]);
    auto StrideCs = argToIntArray(argv[13]);
    // Bugfix: accept the kbatch argument even when trailing optional args follow
    // (was `argc == 15`, which ignored kbatch whenever warm-up/iter were given).
    const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;

    // A stride of -1 means "use the natural packed stride" for row-major A/B/C.
    const int DefaultStrideA = Ks[0];
    const int DefaultStrideB = Ns[0];
    const int DefaultStrideC = Ns[0];

    for(size_t i = 0; i < Ms.size(); ++i)
    {
        StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
        StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
        StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
    }

    int n_warmup = 1;
    int n_iter   = 10;
    // Bugfix: warm-up/iter live at argv[15]/argv[16]; the old code read
    // argv[16]/argv[17], and argv[17] is past the end when argc == 17
    // (argv[argc] is the NULL terminator), so std::stoi dereferenced NULL.
    if(argc >= 17)
    {
        n_warmup = std::stoi(argv[15]);
        n_iter   = std::stoi(argv[16]);
    }

    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
                                                          ck::half_t,
                                                          ck::half_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
                                                          int8_t,
                                                          ck::bhalf_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
                                                          int8_t,
                                                          ck::bhalf_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::ColumnMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
                                                          ck::bhalf_t,
                                                          ck::bhalf_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
                                                          ck::bhalf_t,
                                                          ck::bhalf_t,
                                                          float,
                                                          ck::tensor_layout::gemm::RowMajor,
                                                          ck::tensor_layout::gemm::ColumnMajor,
                                                          ck::tensor_layout::gemm::RowMajor>(
            do_verification,
            init_method,
            do_log,
            time_kernel,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatch,
            n_warmup,
            n_iter);
    }
    else
    {
        throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
    }

    return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage);
...@@ -6,12 +6,6 @@ if(result EQUAL 0) ...@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif() endif()
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple> #include <tuple>
#include <vector> #include <vector>
...@@ -10,25 +10,35 @@ ...@@ -10,25 +10,35 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp" #include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; template <typename Tuple>
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>; class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>; };
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// clang-format off
const std::vector<int> KBATCH{1, 2, 3, 5, 8}; using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Row, Col, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); std::tuple< Col, Row, Row, F16, F16, F16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN, std::tuple< Col, Col, Row, F16, F16, F16>,
RRR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, BF16, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, BF16, BF16>,
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK, std::tuple< Col, Row, Row, BF16, BF16, BF16>,
RCR_F16_F16_F16_LargeK, std::tuple< Row, Row, Row, BF16, I8, BF16>,
testing::Values(32, 64)); std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc" #include "test_grouped_gemm_ut_cases.inc"
#pragma once #pragma once
TEST_P(RRR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TinyCases)
{ {
const std::vector<int> Ms{0, 1}; const std::vector<int> Ms{0, 1};
constexpr int N = 768; constexpr int N = 768;
...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases) ...@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, SmallCases) TYPED_TEST(TestGroupedGemm, SmallCases)
{ {
const std::vector<int> Ms{2, 1, 3, 4, 5, 0}; const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768; constexpr int N = 768;
...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases) ...@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MidCases) TYPED_TEST(TestGroupedGemm, MidCases)
{ {
const std::vector<int> Ms{167, 183, 177, 153, 139, 204}; const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768; constexpr int N = 768;
...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases) ...@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, Regular) TYPED_TEST(TestGroupedGemm, Regular)
{ {
const std::vector<int> Ms{64, 128, 256}; const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768; constexpr int N = 768;
...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular) ...@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RRR_F16_F16_F16, MNKPadded) TYPED_TEST(TestGroupedGemm, MNKPadded)
{ {
const std::vector<int> Ms{127, 150, 188, 210}; const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136; constexpr int N = 136;
...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded) ...@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
TEST_P(RCR_F16_F16_F16, TinyCases) TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
{ {
const std::vector<int> Ms{188, 210}; const std::vector<int> Ms{188, 210};
constexpr int N = 768; constexpr int N = 768;
...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) ...@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const std::vector<int> Ns(Ms.size(), N); const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K); const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) this->k_batches_ = {32, 64};
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); this->Run(Ms, Ns, Ks);
} }
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp" #include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp" #include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck { namespace ck {
namespace test { namespace test {
...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range) ...@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
} }
template <typename Tuple> template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int> class TestGroupedGemm : public testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public: public:
static constexpr bool verify_ = true; static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false; static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
void SetUp() override {} void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms, void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns, const std::vector<int>& Ns,
const std::vector<int>& Ks, const std::vector<int>& Ks,
const std::vector<int>& StrideAs, const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs, const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs, const std::vector<int>& StrideCs = {})
int kbatch = 1, {
int n_warmup = 1, std::vector<int> stride_as = StrideAs;
int n_iter = 10) std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{ {
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType, bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType, BDataType,
...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int> ...@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs, StrideAs,
StrideBs, StrideBs,
StrideCs, StrideCs,
kbatch, kbatches,
n_warmup, n_warmup_,
n_iter); n_iter_);
EXPECT_TRUE(pass);
}
};
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
}; };
...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
return ggemm_instance.IsSupportedArgument(argument); return ggemm_instance.IsSupportedArgument(argument);
...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper ...@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1) if(kbatch > 1)
{ {
ggemm_instance.SetKBatchSize(argument, kbatch); ggemm_instance.SetKBatchSize(&argument, kbatch);
} }
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker(); auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument)); DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false}); return invoker.Run(argument, StreamConfig{nullptr, false});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment