Commit e114409b authored by Adam Osewski

DeviceOp + GridwiseGemm Draft GroupedGEMM+SplitK+TileLoop

* First Part: accumulation across tiles in CThreadBuffer
parent 8febe0f7
......@@ -24,7 +24,11 @@ add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
add_example_executable(example_grouped_gemm_multiple_d_splitk_xdl_fp16 grouped_gemm_multiple_d_splitk_xdl_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_splitk_xdl_fp16)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| AThreadTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BThreadTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcReset| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcReset| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| CoordinateAfter| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| CoordinateAfter| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Run| | | | | | | | Run| | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, false, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, false, 1, 1, 1, S<1, 32, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int k_batch = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
std::vector<const void*> p_As;
std::vector<const void*> p_Bs;
gemm_descs.reserve(group_count);
p_As.reserve(group_count);
p_Bs.reserve(group_count);
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
// a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
// b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
ck::utils::FillConstant<ADataType>{1.f}(a_tensors[i]);
ck::utils::FillConstant<BDataType>{1.f}(b_tensors[i]);
}
}
using GroupedGemmKernelArgument =
ck::tensor_operation::device::GroupedGemmMultipleDKernelArguments<>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero();
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{}});
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatchSize(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
std::vector<ck::index_t> Ms{64};
problem_size.group_count = Ms.size();
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(Ms[i]);
problem_size.Ns.push_back(128);
problem_size.Ks.push_back(128);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (> 0)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing a single GEMM problem's arguments.
///
/// A pointer to a vector of these structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
__host__ __device__
GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
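// Illustrative host-side usage (a minimal sketch mirroring the example added in this commit,
// not part of this header): the launcher builds one such structure per group and copies the
// packed array into the device buffer later passed via SetDeviceKernelArgs. The names p_a,
// p_b, p_e and p_dev_kernel_args below are hypothetical placeholders.
//
//   std::vector<GroupedGemmMultipleDKernelArguments<>> kernel_args;
//   kernel_args.push_back({p_a, p_b, {}, p_e, M, N, K, StrideA, StrideB, {}, StrideE});
//   hip_check_error(hipMemcpy(p_dev_kernel_args,
//                             kernel_args.data(),
//                             kernel_args.size() * sizeof(kernel_args[0]),
//                             hipMemcpyHostToDevice));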
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
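// Illustrative call sequence for this interface (a sketch mirroring the example added in this
// commit; "gemm" stands for a concrete derived instance, and the pointer vectors, GEMM
// descriptors and element-wise operators are assumed to be prepared by the caller):
//
//   auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_op, b_op, cde_op);
//   DeviceMem dev_args(gemm.GetDeviceKernelArgSize(&argument));
//   DeviceMem workspace(gemm.GetWorkSpaceSize(&argument));
//   gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
//   // ... copy the packed kernel arguments into dev_args ...
//   gemm.SetDeviceKernelArgs(argument, dev_args.GetDeviceBuffer());
//   gemm.SetKBatchSize(argument, k_batch);
//   gemm.MakeInvoker().Run(argument, StreamConfig{nullptr, false});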
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include <ck/utility/work_scheduling.hpp>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
/// @param[in] p_workspace Pointer to the auxiliary workgroup workspace used to store
/// partial results.
/// @param[in] tile_count The total number of output tiles across all groups.
/// @param[in] k_batch The number of batches we split the K dimension into.
///
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
/// @tparam GemmDesc The structure holding all necessary descriptors and
/// other data needed for grouped gemm calculation and work
/// distribution.
/// @tparam FloatA Input tensor A elements' data type.
/// @tparam FloatB Input tensor B elements' data type.
/// @tparam FloatC Output tensor C elements' data type.
/// @tparam Block2ETileMapKSplit The structure providing mapping between workgroup ids,
/// the data tiles to process and the output tiles.
/// @tparam HasMainKBlockLoop Flag indicating whether all GEMM problem configurations
/// need to loop over tiles in K dimension.
///
template <typename GridwiseGemm,
typename GemmDesc,
typename FloatA,
typename FloatB,
typename FloatC,
typename Block2ETileMapKSplit,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk_v2(
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
void* const __restrict__ p_workspace,
const index_t tile_count,
const index_t k_batch,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
[[maybe_unused]] const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
uint32_t* const __restrict__ p_flags = reinterpret_cast<uint32_t* const __restrict__>(
reinterpret_cast<char*>(p_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::AccType)));
StridedReductionTileLoop work_scheduler{tile_count, p_flags};
// early exit if no work.
if(work_scheduler.tile_id_ >= tile_count)
return;
if(get_thread_global_1d_id() < work_scheduler.GetFlagCount(k_batch))
p_flags[get_thread_global_1d_id()] = 0;
index_t group_id = 0;
index_t offset = 0;
auto M = gemm_desc_ptr[group_id].M;
auto N = gemm_desc_ptr[group_id].N;
auto b2c_tile_map = Block2ETileMapKSplit(M, N, k_batch);
index_t grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp;
do
{
// Find corresponding GEMM group for our tile
while(!(work_scheduler.tile_id_ >= gemm_tile_id_start &&
work_scheduler.tile_id_ < gemm_tile_id_end))
{
offset += grid_size_grp;
group_id++;
M = gemm_desc_ptr[group_id].M;
N = gemm_desc_ptr[group_id].N;
b2c_tile_map = Block2ETileMapKSplit(M, N, k_batch);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = offset;
gemm_tile_id_end = offset + grid_size_grp;
}
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
// const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
// const auto StrideC = gemm_desc_ptr[group_id].StrideC;
auto gridwise_gemm = GridwiseGemm();
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
do
{
// just accumulate results in registers!
gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map);
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
// Store partial results to the auxiliary workspace.
// make results buffer tensor descriptor (registers).
// make workspace gmem tensor descriptor
// create ThreadGroupTransform and run copy
// if (threadIdx.x == 0)
// {
// // using CThreadBuffer = decltype(results_buffer);
// // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
// constexpr index_t n_scalars = 4;
// static_for<0, n_scalars, 1>{}([&](auto i) {
// printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
// static_cast<index_t>(blockIdx.x),
// i.value,
// static_cast<float>(results_buffer[i]));
// });
// }
}
const index_t output_tile_idx =
__builtin_amdgcn_readfirstlane(b2c_tile_map.GetOutputTileIdx());
const index_t output_tile_idx_offset = __builtin_amdgcn_readfirstlane(offset / k_batch);
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
// The workgroup which processed the first K tile accumulates results and stores them to GMEM
if(b2c_tile_map.IsFirstKSplitBlock())
{
// Wait until all other blocks for this [M,N] tile have stored their results.
work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
// Accumulate partial results. The number of workgroups to reduce may differ per tile,
// so we read the actual flag value.
[[maybe_unused]] const index_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
// if(threadIdx.x == 0)
// {
// // using CThreadBuffer = decltype(results_buffer);
// // constexpr index_t n_scalars = CThreadBuffer::s_per_buf.value;
// constexpr index_t n_scalars = 4;
// static_for<0, n_scalars, 1>{}([&](auto i) {
// printf("[kernel] bid: %d; c_thread_buff[%d]: %f\n",
// static_cast<index_t>(blockIdx.x),
// i.value,
// static_cast<float>(results_buffer[i]));
// });
// }
// TODO: do blockwise reduction from workspace (GMEM) to results_buffer (registers)
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
// TODO do fusion, cshuffle and store results to GMEM
// gridwise_gemm.RunWrite(results_buffer,
// p_c_grid,
// M,
// N,
// K,
// StrideA,
// StrideB,
// StrideC,
// MPadded,
// NPadded,
// KPadded,
// K0,
// k_batch,
// static_cast<void*>(p_shared),
// b2c_tile_map);
}
else
{
// TODO: double buffering in order to not wait for this.
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
} while(work_scheduler.HasTile());
#else
ignore = gemm_descs_const;
ignore = p_workspace;
ignore = tile_count;
ignore = k_batch;
#endif // end of the gfx908/gfx90a/gfx94x architecture guard
}
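// How the tile-loop work distribution above resolves a flat tile id into a (group, local tile)
// pair, written out as a plain host-side sketch (illustrative only; the kernel does this
// incrementally and persists the search state across tiles; desc, group_count, tile_id and
// k_batch are placeholder names):
//
//   index_t offset = 0;
//   for(index_t g = 0; g < group_count; ++g)
//   {
//       const auto b2c_map      = Block2ETileMapKSplit(desc[g].M, desc[g].N, k_batch);
//       const index_t grid_size = b2c_map.CalculateGridSize(desc[g].M, desc[g].N);
//       if(tile_id >= offset && tile_id < offset + grid_size)
//           return std::make_pair(g, tile_id - offset); // group id, local tile id
//       offset += grid_size;
//   }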
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeDataType = EDataType>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
: public DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffle;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2<
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
using KernelArguments = GroupedGemmMultipleDKernelArguments<NumDTensor>;
using Block2ETileMapKSplit = BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock>;
static constexpr index_t DefaultKBatch = 1;
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
int occupancy_num_blocks,
int gpu_cu_count)
: Argument(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_element_op,
b_element_op,
cde_element_op,
DefaultKBatch,
occupancy_num_blocks,
gpu_cu_count)
{
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t kbatch,
int occupancy_num_blocks,
int gpu_cu_count)
: K_BATCH{kbatch},
group_count_{0},
skipped_group_count_{0},
tile_count_{0},
occupancy_num_blocks_{occupancy_num_blocks},
gpu_cu_count_{gpu_cu_count},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("Error! group_count_ != p_As/Bs/Ds/Es size");
}
gemm_kernel_args_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M * N * K == 0)
{
skipped_group_count_++;
continue;
}
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_e = gemm_descs[i].stride_C_;
auto b2c_tile_map = Block2ETileMapKSplit{M, N, K_BATCH};
const index_t grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
tile_count_ += grid_size_grp;
std::array<index_t, NumDTensor> stride_ds;
static_for<0, NumDTensor, 1>{}([&](auto j) {
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
stride_ds[j] = gemm_descs[i].stride_Ds_[j];
});
gemm_kernel_args_.emplace_back(type_convert<const ADataType*>(p_As[i]),
type_convert<const BDataType*>(p_Bs[i]),
p_Ds[i],
type_convert<EDataType*>(p_Es[i]),
M,
N,
K,
stride_a,
stride_b,
stride_ds,
stride_e);
}
}
/**
* @brief Set new kbatch value.
*
* @param[in] kbatch The new splitK parameter value.
*/
void UpdateKBatch(index_t kbatch)
{
K_BATCH = kbatch;
tile_count_ = 0;
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = gemm_kernel_args_[i];
const auto b2c_tile_map = Block2ETileMapKSplit{gemm_arg.M, gemm_arg.N, K_BATCH};
const index_t grid_size_grp =
b2c_tile_map.CalculateGridSize(gemm_arg.M, gemm_arg.N);
tile_count_ += grid_size_grp;
}
}
// private:
index_t K_BATCH;
index_t group_count_;
index_t skipped_group_count_;
// The overall number of output tiles to be processed.
index_t tile_count_;
const void* p_dev_gemm_args_;
int occupancy_num_blocks_;
int gpu_cu_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
std::vector<KernelArguments> gemm_kernel_args_;
};
struct KernelConfig
{
// The oversubscription factor for the number of blocks that can simultaneously reside on
// the GPU.
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
static constexpr int CU_SIMDS = 4;
// Assume we want to have at most 2 waves per SIMD
static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
};
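// Worked example of the cap above (assuming a 64-wide wavefront): with BlockSize = 256,
// BLOCK_WAVES = 256 / 64 = 4 and CU_SIMDS = 4, so CU_BLOCKS = floor(2 * 4 / 4) = 2 blocks
// per CU, i.e. at most 2 waves per SIMD.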
// Invoker
struct Invoker : public BaseInvoker
{
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload uses a user-provided device buffer for kernel
/// arguments.
///
/// @param[in] arg The structure containing kernel arguments (in host
/// memory).
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
/// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary
/// workspace.
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
CheckArgument(arg, stream_config);
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(dev_gemm_workspace == nullptr)
{
std::ostringstream err;
err << "The gemm workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
float ave_time = 0;
if(all_have_main_k_block_loop)
{
ave_time =
DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
else
{
ave_time =
DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
return ave_time;
}
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload uses device buffers (for kernel arguments and
/// for the kernel auxiliary workspace) provided with the argument. The user should
/// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see
/// SetDeviceKernelArgs, @see SetWorkSpacePointer on the argument to properly
/// allocate and set those buffers.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(arg.p_workspace_ == nullptr)
{
std::ostringstream err;
err << "The gemm workspace buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
private:
auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
{
bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
all_have_kbatch_gt_one = arg.K_BATCH > 1;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
}
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = arg.gemm_kernel_args_[i];
if(stream_config.log_level_ > 0)
{
gemm_arg.Print();
}
// Currently all groups use same kbatch value.
auto kbatch = arg.K_BATCH;
if(!GridwiseGemm::CheckValidity(gemm_arg.M,
gemm_arg.N,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.StrideB,
gemm_arg.StrideDs,
gemm_arg.StrideE,
kbatch))
{
std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
arg.gemm_kernel_args_[0].M,
arg.gemm_kernel_args_[0].K,
arg.gemm_kernel_args_[0].StrideA,
arg.K_BATCH);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same)
{
std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! "
<< "group [" << i << "], kbatch: " << kbatch
<< ", group [0], kbatch: " << arg.K_BATCH << " in " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
}
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
HasMainKBlockLoop>;
return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
}
template <typename KernelFunction>
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
const StreamConfig& stream_config) const
{
// Calculate max number of workgroups that can simultaneously reside on the CU.
int occ_num_blocks = 0;
size_t dyn_shared_mem_per_blk = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
int cu_count = getAvailableComputeUnitCount(stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
<< ", available CUs count: " << cu_count << ", occup. grid size: "
<< ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count
<< std::endl;
}
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS);
}
template <typename KernelFunction>
float LaunchKernel(const KernelFunction& kernel,
const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
int max_occupancy_grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
// We launch the smaller of the number of workgroups actually needed for the tiles and
// the number of workgroups that maximizes GPU occupancy, since for some tile
// configurations the former is smaller than the latter. Launching too many workgroups
// means some of them would iterate through all gemm problem descriptors just to find
// out they have nothing to do, which is a waste of GPU cycles.
if(stream_config.log_level_ > 0)
{
const index_t grid_size = ck::math::min(arg.tile_count_, max_occupancy_grid_size);
const index_t tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
std::cout << "tile_count: " << arg.tile_count_
<< ", tiles_per_block: " << tiles_per_block << std::endl;
}
return launch_and_time_kernel(
stream_config,
kernel,
dim3(ck::math::min(arg.tile_count_, max_occupancy_grid_size)),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
dev_gemm_workspace,
arg.tile_count_,
arg.K_BATCH,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = arg.gemm_kernel_args_[i];
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg.M,
gemm_arg.N,
gemm_arg.K,
gemm_arg.StrideA,
gemm_arg.StrideB,
gemm_arg.StrideDs,
gemm_arg.StrideE,
arg.K_BATCH);
if(not group_arg_valid)
{
#if DEBUG_LOG
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
#endif // DEBUG_LOG
}
supported = supported && group_arg_valid;
}
return supported;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return Argument{p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) override
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemm_XdlSplitKTileLoop"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
int occ_grid_size =
arg.gpu_cu_count_ * std::min(arg.occupancy_num_blocks_, KernelConfig::CU_BLOCKS);
int grid_size = std::min(arg.tile_count_, occ_grid_size);
int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
// This is the maximum workspace size that may be needed. The actual grid size, which
// determines the number of workspace bytes needed, may be smaller due to the number of
// CUs available in the stream used to launch the kernel.
size_t size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) +
flag_count * sizeof(uint32_t);
return size_bytes;
}
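// Rough numeric illustration of the sizing above (assumed values: 120 CUs, an occupancy cap
// of 2 blocks per CU, tile_count_ = 64 and K_BATCH = 4): occ_grid_size = 240, so
// grid_size = min(64, 240) = 64, tiles_per_block = 1 and flag_count = (64 + 3) / 4 = 16;
// the workspace then holds the per-block accumulator area plus 16 uint32_t flags.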
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
}
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(KernelArguments);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
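// Schematic scalar form of the fusion above (illustrative only; element-wise operators are
// shown as value-returning functions for brevity):
//
//   acc = 0;
//   for(k = 0; k < K; ++k)
//       acc += a_op(A[m][k]) * b_op(B[n][k]); // B is stored as [N, K]
//   E[m][n] = cde_op(acc, D0[m][n], D1[m][n], ...);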
template <typename ADataType,
typename BDataType,
typename ComputeType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer>
class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
static constexpr index_t KPack = math::max(
math::lcm(AK1, BK1), MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
public:
using AccType = AccDataType;
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
{
return math::integer_least_multiple(K, KPerBlock * K_Batch);
}
__host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
AK1,
I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
BK1,
I1));
}
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static auto
MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
const auto MPad = CalculateMPadded(M);
const auto KPad = CalculateKPadded(K, KBatch);
const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto AK0 = KPad / (KBatch * AK1);
if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
__host__ __device__ static auto
MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
const auto NPad = CalculateNPadded(N);
const auto KPad = CalculateKPadded(K, KBatch);
const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto BK0 = KPad / (KBatch * BK1);
if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
}
private:
using AGridDesc_KBatch_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_KBatch_AK0_M_AK1(1, 1, 1, 1))>;
using BGridDesc_KBatch_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_KBatch_BK0_N_BK1(1, 1, 1, 1))>;
using ABlockDesc_KBatch_AK0PerB_MPerB_AK1 =
remove_cvref_t<decltype(GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1())>;
using BBlockDesc_KBatch_BK0PerB_NPerB_BK1 =
remove_cvref_t<decltype(GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1())>;
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeType,
AGridDesc_KBatch_AK0_M_AK1,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
ABlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeType,
BGridDesc_KBatch_BK0_N_BK1,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
BBlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType,
ComputeType,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>())>;
BlockwiseGemmT blockwise_gemm_{};
public:
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ComputeType),
c_block_size * sizeof(CShuffleDataType));
}
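// Worked example of the LDS budget above (hypothetical parameters, ignoring any extra
// LDS padding from AddExtraM/AddExtraN): with MPerBlock = NPerBlock = 128, KPerBlock = 32,
// AK1 = BK1 = 8 and ComputeType = half_t, the A and B tiles hold (32/8) * 128 * 8 = 4096
// elements each, i.e. 8 KiB apiece in fp16, so the GEMM stage needs about 16 KiB of LDS.
// The C-shuffle stage reuses the same allocation, which is why the two stage sizes are
// combined with math::max rather than summed.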
// E desc for destination in blockwise copy
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
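// Illustrative sketch (hypothetical sizes): for an already padded E of M = 512, N = 256
// with MPerBlock = NPerBlock = 128, the transform above views E as
// (MBlock, MPerBlock, NBlock, NPerBlock) = (4, 128, 2, 128), i.e. a 4 x 2 grid of
// output tiles addressed by the block-to-E-tile map below.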
// Ds desc for source in blockwise copy
template <typename DsGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch)
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
ignore = StrideDs;
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// check gridwise gemm pipeline
const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
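// Illustrative host-side usage (a sketch only; the enclosing class name and argument
// names are hypothetical, and the real call site lives in the device op):
//
//   if(!GridwiseGemm::CheckValidity(M, N, K, StrideA, StrideB, stride_ds, StrideE, k_batch))
//   {
//       // reject this problem size / tuning-parameter combination
//   }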
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
template <typename TensorDataLayout>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, TensorDataLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, TensorDataLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
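// Illustrative sketch of the padding above (hypothetical sizes): with
// GemmSpec = MNKPadding, MPerBlock = NPerBlock = 128 and a raw problem of
// MRaw = 1000, NRaw = 250, the padded E descriptor covers M = 1024, N = 256, so every
// workgroup sees a full MPerBlock x NPerBlock tile and the out-of-range rows and
// columns are handled by the padding transforms.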
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
// TODO: we should refactor all these common Make... descriptors out into something like
// gridwise_gemm_utils.hpp
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
__device__ __host__ constexpr auto& GetCThreadBuffer()
{
return blockwise_gemm_.GetCThreadBuffer();
}
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ void RunGEMM(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
void* __restrict__ p_shared,
[[maybe_unused]] const index_t KBatch,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
const Block2ETileMap& block_2_etile_map)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
// divide block work by [M, N, K]
const auto block_work_idx = block_2_etile_map.GetBottomIndex();
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_kbatch_ak0_m_ak1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
BBlockwiseCopy(b_grid_desc_kbatch_bk0_n_bk1,
make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_kbatch_bk0_n_bk1,
make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop =
__builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
a_block_desc_kbatch_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_kbatch_bk0_n_bk1,
b_block_desc_kbatch_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm_,
c_thread_buf,
num_k_block_main_loop);
}
template <bool HasMainKBlockLoop, typename Block2ETileMap>
__device__ void RunGEMM(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const index_t KBatch,
const Block2ETileMap& block_2_etile_map)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_shared,
KBatch,
a_element_op,
b_element_op,
a_grid_desc_kbatch_ak0_m_ak1,
b_grid_desc_kbatch_bk0_n_bk1,
block_2_etile_map);
}
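// Illustrative tile-loop sketch (hypothetical kernel body and names; the actual kernel
// and work scheduling for this draft live in the device op). The intent of the
// tile-loop + split-K design is that a workgroup may process several k-batch work items
// for the same output tile and accumulate them in the C thread buffer between RunGEMM
// calls, before a single write-out:
//
//   GridwiseGemm gridwise_gemm;
//   while(has_work) // hypothetical per-workgroup work list
//   {
//       gridwise_gemm.template RunGEMM<true>(p_a, p_b, p_shared,
//                                            a_element_op, b_element_op,
//                                            M, N, K, StrideA, StrideB,
//                                            k_batch, block_2_etile_map);
//       // partial results stay in gridwise_gemm.GetCThreadBuffer();
//       // RunWrite (still commented out below) would eventually flush them to E.
//   }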
// template <typename CThreadBuffer,
// InMemoryDataOperationEnum EGlobalMemoryDataOperation,
// index_t NumDTensor_,
// typename DsDataType_,
// typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// typename CDEElementwiseOperation_,
// typename Block2ETileMap>
// __device__ void RunWrite(CThreadBuffer c_thread_buf,
// const EDataType* __restrict__ p_workspace,
// DsGridPointer p_ds_grid,
// EDataType* __restrict__ p_e_grid,
// void* __restrict__ p_shared,
// const index_t KBatch,
// const CDEElementwiseOperation_& cde_element_op,
// const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// const Block2ETileMap& block_2_etile_map)
// {
// using DsGridDesc_M_N =
// remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
// DsGridDesc_M_N ds_grid_desc_m_n;
// const auto ds_grid_buf = generate_tuple(
// [&](auto i) {
// return make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_ds_grid[i],
// ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
// },
// Number<NumDTensor_>{});
// auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// static_for<0, NumDTensor, 1>{}([&](auto j) {
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
// ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
// });
// const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
// // using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
// // remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
// // DsGridDesc_M_N{}))>;
// // DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
// ds_grid_desc_mblock_mperblock_nblock_nperblock;
// // static_for<0, NumDTensor, 1>{}([&](auto j) {
// // ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
// // });
// // const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// // MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// // shuffle C and write out
// static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
// NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
// "wrong!");
// constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
// constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// // TODO: hacky, fix it!
// constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
// blockwise_gemm_.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// // TODO: hacky, fix it!
// // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
// blockwise_gemm_.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
// constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
// constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
// constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
// constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
// constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
// constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
// constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
// constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
// GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// static_cast<CShuffleDataType*>(p_shared),
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
// make_tuple(
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
// M1, // M1 = MWave
// M2, // M2 * M3 * M4 = MPerXdl
// M3,
// M4)),
// make_freeze_transform(I0),
// make_unmerge_transform(make_tuple(
// Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
// N1, // N1 = NWave
// N2))), // N2 = NPerXdl
// make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
// make_tuple(
// Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// // calculate origin of thread output tensor on global memory
// // blockwise GEMM c matrix starting index
// const auto c_thread_mtx_on_block =
// blockwise_gemm_.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
// const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
// const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
// const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
// make_tuple(Sequence<0, 1, 2, 3, 4>{}),
// make_tuple(Sequence<0>{}));
// const auto m_thread_data_on_block_idx =
// m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
// make_multi_index(m_thread_data_on_block));
// const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
// make_single_stage_tensor_adaptor(
// make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
// make_tuple(Sequence<0, 1, 2>{}),
// make_tuple(Sequence<0>{}));
// const auto n_thread_data_on_block_idx =
// n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
// make_multi_index(n_thread_data_on_block));
// // shuffle: threadwise copy C from VGPR to LDS
// auto c_thread_copy_vgpr_to_lds =
// ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
// CShuffleDataType,
// decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
// ck::tensor_operation::element_wise::PassThrough,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// I1,
// I1,
// M2,
// I1,
// M4,
// I1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// 7,
// 1,
// InMemoryDataOperationEnum::Set,
// 1,
// true>{
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// make_multi_index(0,
// 0,
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]),
// ck::tensor_operation::element_wise::PassThrough{}};
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_desc_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
// Number<NumDTensor_>{}));
// // tuple of reference to C/Ds tensor descriptors
// const auto c_ds_buf_refs = concat_tuple_of_reference(
// tie(c_shuffle_block_buf),
// generate_tie(
// [&](auto i) -> const auto& // return type should be reference
// { return ds_grid_buf[i]; },
// Number<NumDTensor_>{}));
// // tuple of starting index of C/Ds blockwise copy
// const auto idx_c_ds_block_begin = container_concat(
// make_tuple(make_multi_index(0, 0, 0, 0)),
// generate_tuple(
// [&](auto) {
// return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
// },
// Number<NumDTensor_>{}));
// // space filling curve for threadwise C in VGPR before shuffle
// constexpr auto sfc_c_vgpr =
// SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
// Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
// Sequence<CShuffleMXdlPerWavePerShuffle,
// CShuffleNXdlPerWavePerShuffle,
// 1,
// 1,
// M2,
// 1,
// M4,
// 1>>{};
// // space filling curve for shuffled blockwise C/D/E
// constexpr auto sfc_cde_block =
// SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
// Sequence<0, 2, 1, 3>,
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
// constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
// // blockwise copy C/D/E between LDS and global
// auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
// ThisThreadBlock,
// decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
// Tuple<EDataType>,
// decltype(c_ds_desc_refs),
// decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
// CDEElementwiseOperation_,
// Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
// // Sequence support
// // arbitrary type
// Sequence<1,
// CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
// 1,
// CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
// CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
// Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
// Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
// 3, // index_t VectorDim,
// CDEShuffleBlockTransferScalarPerVector_NPerBlock,
// sequence_merge_t<
// Sequence<true>,
// uniform_sequence_gen_t<NumDTensor_,
// false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
// Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
// {c_ds_desc_refs,
// idx_c_ds_block_begin,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
// cde_element_op};
// static_for<0, num_access, 1>{}([&](auto access_id) {
// // make sure it's safe to write to LDS
// block_sync_lds();
// // each thread write its data from VGPR to LDS
// c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
// c_thread_buf,
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
// c_shuffle_block_buf);
// // make sure it's safe to read from LDS
// block_sync_lds();
// // each block copy its data from LDS to global
// cde_block_copy_lds_and_global.Run(
// c_ds_desc_refs,
// c_ds_buf_refs,
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// tie(e_grid_buf));
// if constexpr(access_id < num_access - 1)
// {
// constexpr auto cde_lds_and_global_step =
// sfc_cde_block.GetForwardStep(access_id);
// // move on Ds
// static_for<0, NumDTensor_, 1>{}([&](auto i) {
// cde_block_copy_lds_and_global.MoveSrcSliceWindow(
// c_ds_desc_refs, i + I1, cde_lds_and_global_step);
// });
// // move on E
// cde_block_copy_lds_and_global.MoveDstSliceWindow(
// tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
// I0,
// cde_lds_and_global_step);
// }
// });
// }
};
} // namespace ck