Commit 5f7eda63 authored by Adam Osewski

Use default B2C (3D grid) in grid gemm v2r4r2.

parent 96b535ba
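
This change replaces the hard-coded blockIdx decomposition inside the gridwise GEMM with a block-to-C-tile (B2C) map object passed into the kernel; the default is a 3D-grid k-split map. As a rough standalone illustration of what a k-split B2C map computes (hypothetical names, not the CK implementation): a flat block id over a (KBatch, M0, N0) grid decodes into a (k_batch, m0, n0) tile coordinate, and launching a 3D grid makes that decode a direct read of blockIdx.z/y/x.

#include <array>
#include <cstdio>

// Hypothetical sketch of a k-split block-to-C-tile decode. With a 3D
// launch grid, the same triple comes straight from blockIdx.{z,y,x}.
std::array<int, 3> decode_block_id(int block_1d_id, int M0, int N0)
{
    const int n0      = block_1d_id % N0;
    const int m0      = (block_1d_id / N0) % M0;
    const int k_batch = block_1d_id / (N0 * M0);
    return {k_batch, m0, n0};
}

int main()
{
    // A C grid of 2x3 tiles split across 2 k-batches: 12 blocks total.
    for(int bid = 0; bid < 12; ++bid)
    {
        const auto idx = decode_block_id(bid, 2, 3);
        std::printf("block %2d -> (k=%d, m0=%d, n0=%d)\n", bid, idx[0], idx[1], idx[2]);
    }
    return 0;
}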
@@ -19,12 +19,14 @@ namespace ck {
 template <typename GridwiseGemm,
           bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
+                                         const Block2CTileMap& b2c_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
@@ -32,7 +34,7 @@ __global__ void
     __shared__ uint8_t p_shared[shared_size];

     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        karg, static_cast<void*>(p_shared));
+        karg, static_cast<void*>(p_shared), b2c_map);
 #else
     ignore = karg;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
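
Since the kernel now takes the map as an extra argument, call sites construct it on the host and forward it at launch. A hypothetical call site (a sketch only; the grid-extent computation and launch style are assumptions, not taken from this commit):

// Hypothetical launch: build the default 3D-grid map and hand it to the
// kernel. Grid extents follow the (n0, m0, k_batch) layout such a map
// implies; real CK dispatch goes through its own launch helpers.
const auto b2c_map = GridwiseGemm::MakeDefaultBlock2CTileMap();

const auto kernel = kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                                         true, // HasMainKBlockLoop
                                                         InMemoryDataOperationEnum::Set,
                                                         decltype(b2c_map)>;

const dim3 grid((N + NPerBlock - 1) / NPerBlock,  // x -> n0
                (M + MPerBlock - 1) / MPerBlock,  // y -> m0
                KBatch);                          // z -> k_batch

hipLaunchKernelGGL(kernel, grid, dim3(BlockSize), 0, nullptr, karg, b2c_map);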
@@ -456,15 +458,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
     }

-    // return block_id to C matrix tile idx (m0, n0) mapping
-    template <typename CGridDesc>
-    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
-        const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
-    {
-        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
-            c_m_n_grid_desc, 8, KBatch);
-    }
-
     __host__ __device__ static constexpr auto
     GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
     {
@@ -478,8 +471,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

-    template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
+    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
+    {
+        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
+    }
+
+    using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
+    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void Run(const Argument& karg,
+                               void* __restrict__ p_shared_block,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         const FloatAB* p_a_grid = karg.p_a_grid;
         const FloatAB* p_b_grid = karg.p_b_grid;
@@ -506,9 +512,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        // divide block work by [KBatch, M, N]
+        const auto block_work_idx =
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
+        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
+        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
...
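
Note that Run now only requires the map type to provide the two members used in this hunk: CalculateBottomIndex and ValidCTileIndex. A minimal standalone model of that interface (hypothetical and simplified to std::array; CK's real maps operate on its own multi-index types):

#include <array>

// Hypothetical model of the Block2CTileMap interface Run relies on.
struct Block2CTileMapModel
{
    int M0; // tile rows in C
    int N0; // tile columns in C

    // Decode a flat block id into (k_batch, m0, n0).
    std::array<int, 3> CalculateBottomIndex(int block_1d_id) const
    {
        return {block_1d_id / (M0 * N0), (block_1d_id / N0) % M0, block_1d_id % N0};
    }

    // Reject blocks that fall outside the (MBlock, NBlock) tile grid,
    // e.g. when the launch grid was rounded up past the exact tile count.
    bool ValidCTileIndex(const std::array<int, 3>& idx,
                         const std::array<int, 2>& mn_blocks) const
    {
        return idx[1] < mn_blocks[0] && idx[2] < mn_blocks[1];
    }
};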