"...resnet50_tensorflow.git" did not exist on "04ce96360fa73b0d1193ba786be96cecbc1d5333"
Commit af469e6b authored by Adam Osewski

Allocate CThreadBuffer at the global function (kernel) level.

* Drop support for MI100.
* Make GridwiseGEMM methods static and remove its data members.
parent 9205784f
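
In essence, the `__global__` kernel now owns the per-thread accumulation buffer and calls GridwiseGemm only through static member functions, instead of instantiating a GridwiseGemm object that carries the buffer. Below is a minimal, self-contained C++ sketch of that pattern; the names (`AccBuffer`, `Gemm`, `RunGEMM`, `RunWrite`, `TileSize`) are illustrative stand-ins for this note, not the Composable Kernel API.

```cpp
#include <array>
#include <cstdio>

// Illustrative stand-in for the thread-local accumulation buffer
// (in CK this role is played by a StaticBufferTupleOfVector living in VGPRs).
template <typename T, int N>
struct AccBuffer
{
    std::array<T, N> data{};
    void Clear() { data.fill(T{0}); }
};

// Stateless "gridwise GEMM": every entry point is static and the caller
// owns the accumulation buffer, mirroring the refactor in this commit.
struct Gemm
{
    static constexpr int TileSize = 4;

    template <typename Buf>
    static void RunGEMM(const float* a, const float* b, Buf& acc)
    {
        for(int i = 0; i < TileSize; ++i)
            acc.data[i] += a[i] * b[i]; // accumulate partial results "in registers"
    }

    template <typename Buf>
    static void RunWrite(float* e, const Buf& acc)
    {
        for(int i = 0; i < TileSize; ++i)
            e[i] = acc.data[i]; // flush the accumulated tile to the output
    }
};

int main()
{
    float a[Gemm::TileSize] = {1, 2, 3, 4};
    float b[Gemm::TileSize] = {5, 6, 7, 8};
    float e[Gemm::TileSize] = {};

    // The buffer is allocated by the caller (the "global function" / kernel)...
    AccBuffer<float, Gemm::TileSize> results_buffer;
    results_buffer.Clear();

    // ...and threaded through the static GEMM stages.
    Gemm::RunGEMM(a, b, results_buffer);
    Gemm::RunWrite(e, results_buffer);

    std::printf("%g %g %g %g\n", e[0], e[1], e[2], e[3]);
}
```

Keeping the stages static and threading one buffer through them lets the same registers accumulate across several split-K iterations before a single write-out, which mirrors what the kernel's do-loop in the diff below relies on.
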
@@ -12,7 +12,6 @@
 #include "ck/host_utility/hip_check_error.hpp"
 #include "ck/host_utility/stream_utility.hpp"
 #include "ck/utility/common_header.hpp"
-#include "ck/utility/tuple.hpp"
 #include <ck/utility/work_scheduling.hpp>
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -70,8 +69,7 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
@@ -101,7 +99,12 @@ __global__ void
     index_t gemm_tile_id_start = 0;
     index_t gemm_tile_id_end   = grid_size_grp;
-    auto gridwise_gemm = GridwiseGemm();
+    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+                              typename GridwiseGemm::AccType,
+                              GridwiseGemm::GetMPerXdl() * GridwiseGemm::GetNPerXdl(),
+                              GridwiseGemm::GetCThreadBufferVectorSize(),
+                              true>
+        results_buffer;
     do
     {
@@ -128,10 +131,8 @@ __global__ void
         const auto StrideA = gemm_desc_ptr[group_id].StrideA;
         const auto StrideB = gemm_desc_ptr[group_id].StrideB;
-        using VGPRBufferT   = remove_cvref_t<decltype(GridwiseGemm::GetCThreadBuffer())>;
-        auto results_buffer = VGPRBufferT{};
-        b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
         results_buffer.Clear();
+        b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
@@ -139,7 +140,7 @@ __global__ void
         do
         {
            // just accumulate results in registers!
-           gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+           GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
                                                              p_b_grid,
                                                              static_cast<void*>(p_shared),
                                                              a_element_op,
@@ -162,7 +163,7 @@ __global__ void
        // if (changed group_id || next [M,N] tile)
        if(!b2c_tile_map.IsFirstKSplitBlock())
        {
-           gridwise_gemm.StorePartials(p_workspace, results_buffer);
+           GridwiseGemm::StorePartials(p_workspace, results_buffer);
        }
        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -177,7 +178,7 @@ __global__ void
        // Accumulate only when there is at least two workgroups processing splitk data-tiles
        // across same MN-output tile.
        if(neighbour_count > 1)
-           gridwise_gemm.AccumulatePartials(p_workspace, results_buffer, neighbour_count);
+           GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count);
        // Signal waiting blocks that they can start use their workspace.
        work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -196,7 +197,7 @@ __global__ void
            p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
        });
-       gridwise_gemm.template RunWrite(p_ds_grid,
+       GridwiseGemm::template RunWrite(p_ds_grid,
                                        p_e_grid,
                                        static_cast<void*>(p_shared),
                                        M,
@@ -269,54 +269,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     using BBlockDesc_KBatch_BK0PerB_NPerB_BK1 =
         remove_cvref_t<decltype(GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1())>;
-    using ABlockwiseCopy =
-        ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                            AElementwiseOperation,
-                                            ck::tensor_operation::element_wise::PassThrough,
-                                            InMemoryDataOperationEnum::Set,
-                                            Sequence<1, AK0PerBlock, MPerBlock, AK1>,
-                                            ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
-                                            ABlockTransferThreadClusterArrangeOrder,
-                                            ADataType,
-                                            ComputeType,
-                                            AGridDesc_KBatch_AK0_M_AK1,
-                                            ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
-                                            ABlockTransferSrcAccessOrder,
-                                            Sequence<2, 0, 1, 3>,
-                                            ABlockTransferSrcVectorDim,
-                                            3,
-                                            ABlockTransferSrcScalarPerVector,
-                                            ABlockTransferDstScalarPerVector_AK1,
-                                            1,
-                                            1,
-                                            AThreadTransferSrcResetCoordinateAfterRun,
-                                            true,
-                                            NumGemmKPrefetchStage>;
-    using BBlockwiseCopy =
-        ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
-                                            BElementwiseOperation,
-                                            ck::tensor_operation::element_wise::PassThrough,
-                                            InMemoryDataOperationEnum::Set,
-                                            Sequence<1, BK0PerBlock, NPerBlock, BK1>,
-                                            BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
-                                            BBlockTransferThreadClusterArrangeOrder,
-                                            BDataType,
-                                            ComputeType,
-                                            BGridDesc_KBatch_BK0_N_BK1,
-                                            BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
-                                            BBlockTransferSrcAccessOrder,
-                                            Sequence<2, 0, 1, 3>,
-                                            BBlockTransferSrcVectorDim,
-                                            3,
-                                            BBlockTransferSrcScalarPerVector,
-                                            BBlockTransferDstScalarPerVector_BK1,
-                                            1,
-                                            1,
-                                            BThreadTransferSrcResetCoordinateAfterRun,
-                                            true,
-                                            NumGemmKPrefetchStage>;
     public:
     __host__ __device__ static constexpr auto
     GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
@@ -664,13 +616,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         }
     }
-    // TODO: we should refactor out all those common Make... descriptors to sth like
-    // gridwise_gemm_utils.hpp
     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
     __device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }
+    __device__ __host__ static constexpr auto GetMPerXdl() { return MPerXdl; }
+    __device__ __host__ static constexpr auto GetNPerXdl() { return NPerXdl; }
-    __device__ __host__ static constexpr auto& GetCThreadBuffer()
+    __device__ static constexpr auto GetCThreadBufferVectorSize()
     {
         using BlockwiseGemmT =
             remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -686,20 +637,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 NXdlPerWave,
                 KPack,
                 LoopSched>())>;
-        BlockwiseGemmT blockwise_gemm;
-        return blockwise_gemm.GetCThreadBuffer();
+        return BlockwiseGemmT::xdlops_gemm.GetRegSizePerXdlops();
     }
     template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
-    __device__ void RunGEMM(const ADataType* __restrict__ p_a_grid,
+    __device__ static void RunGEMM(const ADataType* __restrict__ p_a_grid,
                             const BDataType* __restrict__ p_b_grid,
                             void* __restrict__ p_shared,
                             const AElementwiseOperation& a_element_op,
                             const BElementwiseOperation& b_element_op,
                             const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
                             const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
                             const Block2ETileMap& block_2_etile_map,
                             CThreadBuf& c_thread_buf)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
@@ -727,6 +677,54 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
             GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
+        using ABlockwiseCopy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                AElementwiseOperation,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<1, AK0PerBlock, MPerBlock, AK1>,
+                                                ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
+                                                ABlockTransferThreadClusterArrangeOrder,
+                                                ADataType,
+                                                ComputeType,
+                                                AGridDesc_KBatch_AK0_M_AK1,
+                                                ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
+                                                ABlockTransferSrcAccessOrder,
+                                                Sequence<2, 0, 1, 3>,
+                                                ABlockTransferSrcVectorDim,
+                                                3,
+                                                ABlockTransferSrcScalarPerVector,
+                                                ABlockTransferDstScalarPerVector_AK1,
+                                                1,
+                                                1,
+                                                AThreadTransferSrcResetCoordinateAfterRun,
+                                                true,
+                                                NumGemmKPrefetchStage>;
+        using BBlockwiseCopy =
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                                                BElementwiseOperation,
+                                                ck::tensor_operation::element_wise::PassThrough,
+                                                InMemoryDataOperationEnum::Set,
+                                                Sequence<1, BK0PerBlock, NPerBlock, BK1>,
+                                                BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
+                                                BBlockTransferThreadClusterArrangeOrder,
+                                                BDataType,
+                                                ComputeType,
+                                                BGridDesc_KBatch_BK0_N_BK1,
+                                                BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
+                                                BBlockTransferSrcAccessOrder,
+                                                Sequence<2, 0, 1, 3>,
+                                                BBlockTransferSrcVectorDim,
+                                                3,
+                                                BBlockTransferSrcScalarPerVector,
+                                                BBlockTransferDstScalarPerVector_BK1,
+                                                1,
+                                                1,
+                                                BThreadTransferSrcResetCoordinateAfterRun,
+                                                true,
+                                                NumGemmKPrefetchStage>;
         // A matrix blockwise copy
         auto a_blockwise_copy =
             ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
@@ -817,19 +815,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }
     template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
-    __device__ void RunGEMM(const void* __restrict__ p_a_grid_,
+    __device__ static void RunGEMM(const void* __restrict__ p_a_grid_,
                             const void* __restrict__ p_b_grid_,
                             void* __restrict__ p_shared,
                             const AElementwiseOperation& a_element_op,
                             const BElementwiseOperation& b_element_op,
                             const index_t M,
                             const index_t N,
                             const index_t K,
                             const index_t StrideA,
                             const index_t StrideB,
                             const index_t KBatch,
                             const Block2ETileMap& block_2_etile_map,
                             CThreadBuf& c_thread_buf)
     {
         const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
         const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
@@ -854,7 +852,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     // TODO Need to do CShuffle already here:
     template <typename CThreadBuf>
-    __device__ void StorePartials(void* __restrict__ p_workspace, const CThreadBuf& c_thread_buf)
+    __device__ static void StorePartials(void* __restrict__ p_workspace,
+                                         const CThreadBuf& c_thread_buf)
     {
         // M0 = grid_size
         // N0 = 1
@@ -999,9 +998,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }
     template <typename CThreadBuf>
-    __device__ void AccumulatePartials(void* __restrict__ p_workspace,
+    __device__ static void AccumulatePartials(void* __restrict__ p_workspace,
                                        CThreadBuf& c_thread_buf,
                                        uint32_t reduce_count)
     {
         using BlockwiseGemmT =
             remove_cvref_t<decltype(BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -1167,16 +1166,16 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }
     template <typename Block2ETileMap, typename CThreadBuf>
-    __device__ void RunWrite(DsGridPointer p_ds_grid,
+    __device__ static void RunWrite(DsGridPointer p_ds_grid,
                              EDataType* __restrict__ p_e_grid,
                              void* __restrict__ p_shared,
                              const index_t M,
                              const index_t N,
                              const std::array<index_t, NumDTensor> StrideDs,
                              const index_t StrideE,
                              const CDEElementwiseOperation& cde_element_op,
                              const Block2ETileMap& block_2_etile_map,
                              const CThreadBuf& c_thread_buf)
     {
         using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;