Commit af469e6b authored by Adam Osewski

Allocate CThreadBuffer at the global (kernel) function level.

* Drop support for MI100.
* Make GridwiseGemm fully static, with no data members (sketched below).
parent 9205784f
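In essence, the refactor replaces the per-instance `GetCThreadBuffer()` accessor with a `StaticBufferTupleOfVector` declared directly in the `__global__` kernel, sized via the static `GetMPerXdl()`, `GetNPerXdl()`, and `GetCThreadBufferVectorSize()` queries, and threads that buffer through now-static `RunGEMM`, `StorePartials`, `AccumulatePartials`, and `RunWrite` calls. Below is a minimal, self-contained HIP-style sketch of that pattern; `GemmPolicy`, `AccBuffer`, and `kernel_gemm` are illustrative stand-ins, not composable_kernel's actual types.

```cpp
// Minimal sketch of "accumulator at kernel scope + all-static gridwise policy".
// GemmPolicy / AccBuffer / kernel_gemm are hypothetical names for illustration only.
#include <hip/hip_runtime.h>

template <int RegCount>
struct AccBuffer // stand-in for the VGPR-resident CThreadBuffer
{
    float data[RegCount];

    __device__ void Clear()
    {
        for(int i = 0; i < RegCount; ++i)
            data[i] = 0.f;
    }
};

struct GemmPolicy // stand-in for GridwiseGemm: only static methods, no data members
{
    static constexpr int MPerXdl = 16;
    static constexpr int NPerXdl = 16;

    // The kernel asks the policy how large the register buffer must be.
    __host__ __device__ static constexpr int GetAccBufferSize()
    {
        return (MPerXdl * NPerXdl) / 64;
    }

    template <typename Buf>
    __device__ static void RunGEMM(const float* a, const float* b, Buf& acc)
    {
        // Accumulate partial results into the caller-owned register buffer.
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        acc.data[0] += a[idx] * b[idx];
    }

    template <typename Buf>
    __device__ static void RunWrite(float* e, const Buf& acc)
    {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        e[idx] = acc.data[0];
    }
};

template <typename Gemm>
__global__ void kernel_gemm(const float* a, const float* b, float* e)
{
    // The buffer is declared once at __global__ function scope and reused across
    // the work loop, instead of being fetched from a GridwiseGemm instance.
    AccBuffer<Gemm::GetAccBufferSize()> results_buffer;

    results_buffer.Clear();
    Gemm::RunGEMM(a, b, results_buffer); // static calls: no gemm object needed
    Gemm::RunWrite(e, results_buffer);
}

// Example launch (host side):
// kernel_gemm<GemmPolicy><<<grid, block>>>(d_a, d_b, d_e);
```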
@@ -12,7 +12,6 @@
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include <ck/utility/work_scheduling.hpp>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -70,8 +69,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
@@ -101,7 +99,12 @@ __global__ void
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp;
auto gridwise_gemm = GridwiseGemm();
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
typename GridwiseGemm::AccType,
GridwiseGemm::GetMPerXdl() * GridwiseGemm::GetNPerXdl(),
GridwiseGemm::GetCThreadBufferVectorSize(),
true>
results_buffer;
do
{
@@ -128,10 +131,8 @@ __global__ void
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
using VGPRBufferT = remove_cvref_t<decltype(GridwiseGemm::GetCThreadBuffer())>;
auto results_buffer = VGPRBufferT{};
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
results_buffer.Clear();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
@@ -139,7 +140,7 @@ __global__ void
do
{
// just accumulate results in registers!
gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
@@ -162,7 +163,7 @@ __global__ void
// if (changed group_id || next [M,N] tile)
if(!b2c_tile_map.IsFirstKSplitBlock())
{
gridwise_gemm.StorePartials(p_workspace, results_buffer);
GridwiseGemm::StorePartials(p_workspace, results_buffer);
}
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -177,7 +178,7 @@ __global__ void
// Accumulate only when there are at least two workgroups processing splitk data-tiles
// across the same MN-output tile.
if(neighbour_count > 1)
gridwise_gemm.AccumulatePartials(p_workspace, results_buffer, neighbour_count);
GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count);
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -196,7 +197,7 @@ __global__ void
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
gridwise_gemm.template RunWrite(p_ds_grid,
GridwiseGemm::template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
......
@@ -269,54 +269,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using BBlockDesc_KBatch_BK0PerB_NPerB_BK1 =
remove_cvref_t<decltype(GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1())>;
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeType,
AGridDesc_KBatch_AK0_M_AK1,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
ABlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeType,
BGridDesc_KBatch_BK0_N_BK1,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
BBlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
public:
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
@@ -664,13 +616,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
}
// TODO: we should refactor out all those common Make... descriptors to something like
// gridwise_gemm_utils.hpp
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
__device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }
__device__ __host__ static constexpr auto GetMPerXdl() { return MPerXdl; }
__device__ __host__ static constexpr auto GetNPerXdl() { return NPerXdl; }
__device__ __host__ static constexpr auto& GetCThreadBuffer()
__device__ static constexpr auto GetCThreadBufferVectorSize()
{
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -686,20 +637,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
NXdlPerWave,
KPack,
LoopSched>())>;
BlockwiseGemmT blockwise_gemm;
return blockwise_gemm.GetCThreadBuffer();
return BlockwiseGemmT::xdlops_gemm.GetRegSizePerXdlops();
}
template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
__device__ void RunGEMM(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf)
__device__ static void RunGEMM(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
@@ -727,6 +677,54 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
using ABlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ComputeType,
AGridDesc_KBatch_AK0_M_AK1,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
ABlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
using BBlockwiseCopy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<1, BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
ComputeType,
BGridDesc_KBatch_BK0_N_BK1,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
BBlockTransferSrcAccessOrder,
Sequence<2, 0, 1, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>;
// A matrix blockwise copy
auto a_blockwise_copy =
ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
@@ -817,19 +815,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
__device__ void RunGEMM(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const index_t KBatch,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf)
__device__ static void RunGEMM(const void* __restrict__ p_a_grid_,
const void* __restrict__ p_b_grid_,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const index_t KBatch,
const Block2ETileMap& block_2_etile_map,
CThreadBuf& c_thread_buf)
{
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
@@ -854,7 +852,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// TODO Need to do CShuffle already here:
template <typename CThreadBuf>
__device__ void StorePartials(void* __restrict__ p_workspace, const CThreadBuf& c_thread_buf)
__device__ static void StorePartials(void* __restrict__ p_workspace,
const CThreadBuf& c_thread_buf)
{
// M0 = grid_size
// N0 = 1
@@ -999,9 +998,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template <typename CThreadBuf>
__device__ void AccumulatePartials(void* __restrict__ p_workspace,
CThreadBuf& c_thread_buf,
uint32_t reduce_count)
__device__ static void AccumulatePartials(void* __restrict__ p_workspace,
CThreadBuf& c_thread_buf,
uint32_t reduce_count)
{
using BlockwiseGemmT =
remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -1167,16 +1166,16 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template <typename Block2ETileMap, typename CThreadBuf>
__device__ void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const CDEElementwiseOperation& cde_element_op,
const Block2ETileMap& block_2_etile_map,
const CThreadBuf& c_thread_buf)
__device__ static void RunWrite(DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const index_t M,
const index_t N,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const CDEElementwiseOperation& cde_element_op,
const Block2ETileMap& block_2_etile_map,
const CThreadBuf& c_thread_buf)
{
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
......