Commit e87f7319 authored by Jing Zhang's avatar Jing Zhang
Browse files

dedicated fixed_nk solution

parent 5a5468f4
......@@ -60,8 +60,6 @@ int main()
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
for(int i = 0; i < group_count; ++i)
{
Ms.push_back(256 + 256 * distrib(gen));
......@@ -71,8 +69,6 @@ int main()
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
}
auto f_matrix_space_size =
......@@ -106,10 +102,6 @@ int main()
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_dev_bufs.emplace_back(sizeof(ADataType) *
......@@ -119,23 +111,11 @@ int main()
e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});
p_a.push_back(a_dev_bufs[i].GetDeviceBuffer());
p_b.push_back(b_dev_bufs[i].GetDeviceBuffer());
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
......@@ -182,20 +162,13 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
// op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string op_name = op_ptr->GetTypeString();
hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetWorkSpaceSize(argument_ptr.get()),
hipMemcpyHostToDevice);
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(),
gemm_desc_workspace.GetDeviceBuffer(),
StreamConfig{nullptr, true});
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
......
......@@ -7,6 +7,8 @@ add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32
......@@ -14,7 +16,9 @@ add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16)
example_grouped_gemm_xdl_splitk_fp16
example_grouped_gemm_xdl_fixed_nk_fp16
)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero();
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{}});
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetWorkSpaceSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ns.push_back(768);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}
......@@ -10,7 +10,6 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -47,7 +46,7 @@ using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
......@@ -58,268 +57,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
#include "run_grouped_gemm_example.inc"
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero();
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{}});
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetWorkSpaceSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ns.push_back(768);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
......@@ -20,24 +20,6 @@ struct GemmDesc
std::vector<ck::index_t> stride_Ds_;
};
template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
......@@ -66,8 +48,6 @@ struct DeviceGroupedGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
};
} // namespace device
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -24,13 +24,6 @@ namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap,
typename GroupedGemmBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
......@@ -41,7 +34,6 @@ __global__ void
#endif
kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const index_t grid_size_grp,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
......@@ -55,7 +47,6 @@ __global__ void
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
......@@ -73,14 +64,7 @@ __global__ void
}
group_id = index_t((left + right) / 2);
}
#endif
const index_t group_id = block_id / grid_size_grp;
if(group_id >= group_count)
return;
#if 0
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
......@@ -95,83 +79,6 @@ __global__ void
gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_etile_map_);
#else
const index_t M = gemm_desc_ptr[group_id].M;
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
#if 0
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
#endif
using DsDataType = ck::Tuple<>;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
const index_t BlockStart = group_id * grid_size_grp;
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n};
constexpr auto NumDTensor = 0;
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid_;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
auto m_loops = local_b2e_tile_map.CalculateMLoops();
index_t m_id = 0;
do
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
block_2_etile_map);
m_id += 1;
} while(m_id < m_loops);
#endif
#else
ignore = gemm_descs_const;
ignore = group_count;
......@@ -374,162 +281,54 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
struct GroupedGemmBlock2ETileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
__host__ __device__
OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start,
index_t mblock_id_off = 0)
GroupedGemmBlock2ETileMap()
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
mblock_id_off_ = mblock_id_off;
block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
BlockStart_ = -1;
}
GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
{
block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
BlockStart_ = BlockStart;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
return block_2_etile_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] - BlockStart_));
}
// it's actually E-Tile
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
index_t mblock_id_off_;
};
template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__
BlockToCTileMap_M00_N0_M01Adapt_MLoops(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt_MLoops(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ __device__ constexpr index_t CalculateMLoops() const
{
return math::integer_divide_ceil(M_, MPerBlock_);
}
__host__ static constexpr index_t CalculateGridSize(index_t /*M*/, index_t N)
{
const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
__host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
}
private:
index_t M_;
index_t N_;
index_t M01_;
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_;
};
using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
struct GemmBiasTransKernelArg
{
// pointers
const void* a_ptr_;
const void* b_ptr_;
std::array<const void*, NumDTensor> ds_ptr_;
void* e_ptr_;
index_t M_, N_, K_;
index_t StrideA_, StrideB_;
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
const ADataType* a_ptr_;
const BDataType* b_ptr_;
typename GridwiseGemm::DsGridPointer ds_ptr_;
EDataType* e_ptr_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
......@@ -545,7 +344,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
GroupedGemmBlock2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
......@@ -564,36 +363,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
grid_size_ = 0;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
0 == ck::type_convert<ck::index_t>(p_As.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
0 == ck::type_convert<ck::index_t>(p_Bs.size())))
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
0 == ck::type_convert<ck::index_t>(p_Ds.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Es");
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
index_t group_id = 0;
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
......@@ -604,17 +385,23 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
a_mtx_mraw_kraw_.emplace_back(M, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
if(M == 0)
{
skipped_group_count_++;
continue;
}
const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].stride_C_;
// pointer
std::array<const void*, NumDTensor> p_ds_grid;
typename GridwiseGemm::DsGridPointer p_ds_grid{};
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
});
// tensor descriptors for problem definiton
......@@ -623,16 +410,16 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
DsGridDesc_M_N ds_grid_desc_m_n;
std::array<index_t, NumDTensor> StrideDs;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
});
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
......@@ -640,26 +427,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp << std::endl;
const index_t grid_size_grp =
GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
.block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
local_b2c_tile_map))
block_2_etile_map))
{
// tensor descriptors for block/thread-wise copy
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -676,17 +461,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
static_cast<const BDataType*>(p_Bs[i]),
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<EDataType*>(p_Es[i]),
a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
......@@ -695,17 +473,16 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
local_b2c_tile_map,
block_2_etile_map,
BlockStart,
BlockEnd});
}
group_id++;
}
}
// private:
index_t group_count_;
index_t skipped_group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
......@@ -715,8 +492,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
};
......@@ -729,12 +504,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
bool has_main_k_block_loop = true;
#if 1
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
#endif
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
......@@ -777,81 +546,33 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
#if 1
grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_,
arg.gemm_desc_kernel_arg_[i].ds_ptr_,
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
arg.gemm_desc_kernel_arg_[i].K_,
arg.gemm_desc_kernel_arg_[i].StrideA_,
arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_});
#endif
}
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() *
sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
GemmBiasTransKernelArg,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
const void* kernel_args_dev = nullptr;
if(arg.grouped_gemm_kernel_args_dev != nullptr)
{
kernel_args_dev = arg.grouped_gemm_kernel_args_dev;
}
else
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
{
throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
}
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
grouped_gemm_kernel_args.data(),
grouped_gemm_kernel_args.size() *
sizeof(GroupedGemmKernelArgument<NumDTensor>),
hipMemcpyHostToDevice,
stream_config.stream_id_));
kernel_args_dev = arg.p_workspace_;
}
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(kernel_args_dev),
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
......@@ -879,7 +600,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
return false;
}
......@@ -979,21 +701,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ *
sizeof(GroupedGemmKernelArgument<NumDTensor>);
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
}
};
......
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap,
typename GroupedGemmBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const index_t grid_size_grp,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
#endif
const index_t group_id = block_id / grid_size_grp;
if(group_id >= group_count)
return;
#if 0
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
gemm_desc_ptr[group_id].ds_ptr_,
gemm_desc_ptr[group_id].e_ptr_,
p_shared,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_etile_map_);
#else
const index_t M = gemm_desc_ptr[group_id].M;
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
#if 0
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
#endif
using DsDataType = ck::Tuple<>;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
const index_t BlockStart = group_id * grid_size_grp;
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n};
constexpr auto NumDTensor = 0;
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid_;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
auto m_loops = local_b2e_tile_map.CalculateMLoops();
index_t m_id = 0;
do
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
c_element_op,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
block_2_etile_map);
m_id += 1;
} while(m_id < m_loops);
#endif
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType, // TODO: distinguish A/B datatype
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumPrefetch, // NumGemmKPrefetchStage
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__
OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start,
index_t mblock_id_off = 0)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
mblock_id_off_ = mblock_id_off;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
index_t mblock_id_off_;
};
template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__
BlockToCTileMap_M00_N0_M01Adapt_MLoops(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt_MLoops(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ __device__ constexpr index_t CalculateMLoops() const
{
return math::integer_divide_ceil(M_, MPerBlock_);
}
__host__ static constexpr index_t CalculateGridSize(index_t /*M*/, index_t N)
{
const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t M01_;
};
using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
struct GemmBiasTransKernelArg
{
// pointers
const void* a_ptr_;
const void* b_ptr_;
std::array<const void*, NumDTensor> ds_ptr_;
void* e_ptr_;
index_t M_, N_, K_;
index_t StrideA_, StrideB_;
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
{
grid_size_ = 0;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
0 == ck::type_convert<ck::index_t>(p_As.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
0 == ck::type_convert<ck::index_t>(p_Bs.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
0 == ck::type_convert<ck::index_t>(p_Ds.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Es");
}
gemm_desc_kernel_arg_.reserve(group_count_);
index_t group_id = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
a_mtx_mraw_kraw_.emplace_back(M, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].stride_C_;
// pointer
std::array<const void*, NumDTensor> p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
});
// tensor descriptors for problem definiton
const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
DsGridDesc_M_N ds_grid_desc_m_n;
std::array<index_t, NumDTensor> StrideDs;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
});
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp << std::endl;
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
local_b2c_tile_map))
{
// tensor descriptors for block/thread-wise copy
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[j]);
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
local_b2c_tile_map,
BlockStart,
BlockEnd});
}
group_id++;
}
}
// private:
index_t group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
#if 1
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
#endif
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
<< "}";
std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
<< "}";
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
#if 1
grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_,
arg.gemm_desc_kernel_arg_[i].ds_ptr_,
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
arg.gemm_desc_kernel_arg_[i].K_,
arg.gemm_desc_kernel_arg_[i].StrideA_,
arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_});
#endif
}
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel =
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
const void* kernel_args_dev = nullptr;
if(arg.grouped_gemm_kernel_args_dev != nullptr)
{
kernel_args_dev = arg.grouped_gemm_kernel_args_dev;
}
else
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
{
throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
}
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
grouped_gemm_kernel_args.data(),
grouped_gemm_kernel_args.size() *
sizeof(GroupedGemmKernelArgument<NumDTensor>),
hipMemcpyHostToDevice,
stream_config.stream_id_));
kernel_args_dev = arg.p_workspace_;
}
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
{
return false;
}
bool supported = true;
// If we use padding we do not support vector loads for dimensions not divisible by vector
// load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
// thus we have to adapt it to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
}
}
return supported;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
{
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemm_Xdl_Fixed_NK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ *
sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment