Commit 5f7eda63 authored by Adam Osewski

Use default B2C (3D grid) in grid gemm v2r4r2.

parent 96b535ba
@@ -19,12 +19,14 @@ namespace ck {

 template <typename GridwiseGemm,
           bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+          typename Block2CTileMap>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
+                                             const Block2CTileMap& b2c_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
@@ -32,7 +34,7 @@ __global__ void
     __shared__ uint8_t p_shared[shared_size];

     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        karg, static_cast<void*>(p_shared));
+        karg, static_cast<void*>(p_shared), b2c_map);
 #else
     ignore = karg;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
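With this hunk the kernel no longer assumes a fixed grid layout: the block-to-C-tile map is now both a template parameter and a runtime argument, so the host picks the mapping at launch time. A minimal sketch of what a caller could look like after this change; the launcher function below is hypothetical plain HIP, and only the kernel name, its template parameters, and MakeDefaultBlock2CTileMap() come from this diff.

// Hypothetical host-side launcher, for illustration only. It shows the new
// calling convention: construct the default map on the host and pass it as
// the last kernel argument.
#include <hip/hip_runtime.h>

template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
void launch_gemm_v2r4r2(const typename GridwiseGemm::Argument& karg,
                        dim3 grid_dim, dim3 block_dim, hipStream_t stream)
{
    // Default 3D-grid K-split map; it is stateless and constexpr-constructible,
    // so it can be created here and passed to the kernel by value.
    const auto b2c_map = GridwiseGemm::MakeDefaultBlock2CTileMap();

    const auto kernel =
        kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
                                             HasMainKBlockLoop,
                                             CGlobalMemoryDataOperation,
                                             ck::remove_cvref_t<decltype(b2c_map)>>;

    hipLaunchKernelGGL(kernel, grid_dim, block_dim, 0, stream, karg, b2c_map);
}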
@@ -456,15 +458,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
     }

-    // return block_id to C matrix tile idx (m0, n0) mapping
-    template <typename CGridDesc>
-    __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
-        const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
-    {
-        return BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc>(
-            c_m_n_grid_desc, 8, KBatch);
-    }
-
     __host__ __device__ static constexpr auto
     GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
     {
@@ -478,8 +471,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

-    template <bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation>
-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
+    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
+    __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
+    {
+        return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
+    }
+
+    using CGridDesc_M_N         = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
+    using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
+
+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void Run(const Argument& karg,
+                               void* __restrict__ p_shared_block,
+                               const Block2CTileMap& block_2_ctile_map)
     {
         const FloatAB* p_a_grid = karg.p_a_grid;
         const FloatAB* p_b_grid = karg.p_b_grid;
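The new Run body (next hunk) only requires two operations from Block2CTileMap: CalculateBottomIndex, turning a block id into a (k_batch, m0, n0) triple, and ValidCTileIndex, a bounds check against the C tile grid. As a mental model of the 3D-grid variant, here is a toy stand-in, assuming the map simply reads the three launch-grid axes; the real BlockToCTileMap_3DGrid_KSplit in the repository may differ in detail.

// Toy model of a 3D-grid K-split block-to-C-tile map, for illustration only.
// With a 3D launch grid no per-block integer division is needed: the
// (k_batch, m0, n0) tile index is read straight off blockIdx, and the 1D id
// passed in by Run is ignored. Relies on ck utility headers.
template <ck::index_t MPerBlock_, ck::index_t NPerBlock_>
struct ToyBlock2CTileMap_3DGrid_KSplit
{
    template <typename TopIdx>
    __device__ auto CalculateBottomIndex(const TopIdx& /* block_1d_id */) const
    {
        // z -> k_batch, y -> m0, x -> n0
        return ck::make_multi_index(blockIdx.z, blockIdx.y, blockIdx.x);
    }

    template <typename CTileIdx, typename CTileDim>
    __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                    const CTileDim& c_tile_dim) const
    {
        // c_tile_dim carries (MBlock, NBlock); components 1 and 2 of the
        // index are (m0, n0). Blocks that land outside simply exit early.
        return c_tile_idx[ck::Number<1>{}] < c_tile_dim[ck::Number<0>{}] &&
               c_tile_idx[ck::Number<2>{}] < c_tile_dim[ck::Number<1>{}];
    }
};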
@@ -506,9 +512,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);

-        const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        // divide block work by [KBatch, M, N]
+        const auto block_work_idx =
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        if(!block_2_ctile_map.ValidCTileIndex(
+               block_work_idx,
+               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
+                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
+        {
+            return;
+        }
+
+        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
+        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

         // HACK: this force m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
......
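A note on the validity check introduced above: when a map derives its launch grid from rounded-up tile counts, the padding blocks must bail out before touching memory. For the default 3D-grid map the implied grid is just the tile counts themselves; a sketch of the arithmetic follows, with a hypothetical helper name and explicit parameters for self-containedness.

// Hypothetical helper showing the launch grid implied by the default 3D map.
// Worked example: M = 1000, N = 512, MPerBlock = NPerBlock = 128, KBatch = 4
// -> MBlock = 8, NBlock = 4, grid = (4, 8, 4) = 128 blocks, one per
// (k_batch, m0, n0) tile.
dim3 CalculateGridSize(ck::index_t M, ck::index_t N, ck::index_t KBatch,
                       ck::index_t MPerBlock, ck::index_t NPerBlock)
{
    const ck::index_t MBlock = (M + MPerBlock - 1) / MPerBlock; // ceil-div
    const ck::index_t NBlock = (N + NPerBlock - 1) / NPerBlock;
    return dim3(NBlock, MBlock, KBatch); // x -> n0, y -> m0, z -> k_batch
}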