Commit b2a49620 authored by carlushuang

shrink karg for streamk

parent fcb2911e
@@ -143,8 +143,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             // TODO: remove clear buffer for streamk kernels
             hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

-            ave_time =
-                launch_and_time_kernel(stream_config, kernel, grid_dims, dim3(BlockSize), 0, karg);
+            ave_time = launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              grid_dims,
+                                              dim3(BlockSize),
+                                              0,
+                                              karg.p_a_grid,
+                                              karg.p_b_grid,
+                                              karg.p_c_grid,
+                                              karg.M,
+                                              karg.N,
+                                              karg.K,
+                                              karg.StrideA,
+                                              karg.StrideB,
+                                              karg.StrideC,
+                                              karg.block_mapping);

             return ave_time;
         }
...
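The host-side change above is the whole point of the commit title: the kernel now receives only the fields the device code actually reads, instead of the full `Argument` struct by value. A minimal before/after sketch of the pattern, with hypothetical names (`KernelArg`, `gemm_before`, `gemm_after` are illustrative, not real CK symbols):

```cpp
#include <hip/hip_runtime.h>

// Hypothetical argument struct: everything the host tracks, including
// members the device never touches.
struct KernelArg
{
    const float* p_a;
    const float* p_b;
    float* p_c;
    int M, N, K, StrideA, StrideB, StrideC;
    // ... host-only bookkeeping would also ride along here
};

// Before: the entire struct is copied into kernel argument space at launch.
__global__ void gemm_before(KernelArg karg);

// After: only the needed pointers and scalars travel to the device,
// shrinking the kernarg footprint and avoiding the by-value struct copy.
__global__ void gemm_after(const float* p_a, const float* p_b, float* p_c,
                           int M, int N, int K,
                           int StrideA, int StrideB, int StrideC);
```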
@@ -8,6 +8,7 @@
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include <limits>
+#include <stdlib.h>

 namespace ck {
@@ -635,14 +636,18 @@ struct BlockToCTileMap_3DGrid_KSplit
         return true;
     }
 };

-#include <stdlib.h>
-
-template <uint32_t MPerBlock_, uint32_t NPerBlock_, uint32_t KPerBlock_>
+template <uint32_t MPerBlock_,
+          uint32_t NPerBlock_,
+          uint32_t KPerBlock_,
+          uint32_t TileSwizzleSubM_ = 8>
 struct BlockToCTileMap_GemmStreamK
 {
     static constexpr uint32_t min_k_iters_per_sk_block = 2;
     static constexpr uint32_t MPerBlock = MPerBlock_;
     static constexpr uint32_t NPerBlock = NPerBlock_;
     static constexpr uint32_t KPerBlock = KPerBlock_;
+    static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;

     //--------------------------------------
     // pass to device
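The swizzle group height also moves from a runtime `MDiv` member (previously overridable through an environment variable, see the hunks below) to the compile-time constant `tile_swizzle_sub_m` above. A plausible reading of why this helps shrink the kernel argument: once the height is a constexpr power of two, plain `/` and `%` fold to a shift and a mask, so no magic-number divisor state needs to be stored or passed. A sketch, assuming the default `TileSwizzleSubM_ = 8`:

```cpp
#include <cstdint>

constexpr uint32_t tile_swizzle_sub_m = 8; // TileSwizzleSubM_

// With a constexpr power-of-two height, the compiler lowers this divmod to
// `m_tile_idx >> 3` and `m_tile_idx & 7`; no MDiv magic numbers required.
inline void split_m_tile(uint32_t m_tile_idx, uint32_t& group, uint32_t& row_in_group)
{
    group        = m_tile_idx / tile_swizzle_sub_m;
    row_in_group = m_tile_idx % tile_swizzle_sub_m;
}
```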
@@ -657,8 +662,8 @@ struct BlockToCTileMap_GemmStreamK
     uint32_t k_iters_per_big_block;
     MDiv k_iters_per_tile;
     MDiv n_tiles;
-    MDiv tile_swizzle_sub_m;
-    MDiv tile_swizzle_sub_m_rem;
+    // MDiv tile_swizzle_sub_m_rem;

     //--------------------------------------
     static int env_get_int(const char* var_name, int default_int)
@@ -676,8 +681,7 @@ struct BlockToCTileMap_GemmStreamK
                uint32_t k,
                uint32_t num_cu,
                uint32_t occupancy,
-               uint32_t sk_blocks = 0xffffffff,
-               uint32_t tile_swizzle_sub_m_factor = 8)
+               uint32_t sk_blocks = 0xffffffff)
     {
         uint32_t num_tiles =
             math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
@@ -723,16 +727,9 @@ struct BlockToCTileMap_GemmStreamK
             sk_tiles = partial_dispatche_tiles + num_cu;
         }

-        // dp_num_blocks = dp_tiles;
-        // dp_start_block_idx = num_cu * sk_occupancy;
         dp_iters_per_block = k_iters_per_tile.get();
         sk_total_iters     = k_iters_per_tile.get() * sk_tiles;

-        // printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
-        // partial_dispatche_tiles:%d\n",
-        //        num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
         {
             uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
             uint32_t max_sk_tiles =
@@ -812,11 +809,8 @@ struct BlockToCTileMap_GemmStreamK
         }

         n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
-        tile_swizzle_sub_m_factor =
-            env_get_int("tile_swizzle_sub_m_factor", tile_swizzle_sub_m_factor);
-        tile_swizzle_sub_m = MDiv(tile_swizzle_sub_m_factor);
-        tile_swizzle_sub_m_rem =
-            MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m_factor);
+        // tile_swizzle_sub_m_rem =
+        //     MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);

         printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
@@ -896,22 +890,25 @@ struct BlockToCTileMap_GemmStreamK
         // swizzle tile
         uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
+        // uint32_t n_tiles = math::integer_divide_ceil(n, NPerBlock);

-        uint32_t quo_sub_m, rem_sub_m;
-        tile_swizzle_sub_m.divmod(m_tile_idx, quo_sub_m, rem_sub_m);
+        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

-        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem.get()))
+        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                      ? tile_swizzle_sub_m
                                      : tile_swizzle_sub_m_rem;

         uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
-        tile_swizzle_sub_m.divmod(m_tile_idx, m_tile_idx_sub0, m_tile_idx_sub1);
+        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
+        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

         uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles.get();

         uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
-        sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);
-        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m.get(),
+        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
+        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
+        // sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);
+        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                           n_tile_idx_with_adapt);
     }
 };
...
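A host-side replica of the rewritten mapping may help here (a sketch: the struct members `tile_swizzle_sub_m` and `n_tiles` become plain parameters, and the incoming `(m_tile_idx, n_tile_idx)` is assumed to be the row-major split of a linear tile index):

```cpp
#include <cstdint>
#include <utility>

// Replica of the swizzled (m, n) C-tile mapping after this commit.
std::pair<uint32_t, uint32_t> swizzle_tile(uint32_t m_tile_idx, uint32_t n_tile_idx,
                                           uint32_t m_tiles, uint32_t n_tiles,
                                           uint32_t sub_m /* tile_swizzle_sub_m */)
{
    uint32_t rem = m_tiles % sub_m;
    // The last, partial group of M-tiles uses the remainder as its height.
    uint32_t sub_m_adapt = (m_tile_idx < m_tiles - rem) ? sub_m : rem;

    uint32_t sub0 = m_tile_idx / sub_m; // which group of sub_m M-tiles
    uint32_t sub1 = m_tile_idx % sub_m; // position within the group

    uint32_t local = n_tile_idx + sub1 * n_tiles;
    return {local % sub_m_adapt + sub0 * sub_m, // swizzled m tile index
            local / sub_m_adapt};               // swizzled n tile index
}
```

For example, with `m_tiles = 16`, `n_tiles = 4`, `sub_m = 8`, blocks 0..7 land on tiles (0,0)..(7,0) and blocks 8..15 on (0,1)..(7,1): consecutive blocks fill an 8-tall stripe of C tiles before moving across, the usual cache-locality motivation for this kind of swizzle.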
@@ -24,14 +24,36 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
+    // kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
+                               const typename GridwiseGemm::FloatAB* p_b_grid,
+                               typename GridwiseGemm::FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

-    GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
+    // GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      M,
+                      N,
+                      K,
+                      StrideA,
+                      StrideB,
+                      StrideC,
+                      block_mapping,
+                      static_cast<void*>(p_shared));
 #else
     ignore = karg;
     ignore = b2c_map;
@@ -40,9 +62,9 @@ __global__ void

 template <index_t BlockSize,
           typename Block2CTileMap_,
-          typename FloatAB,
+          typename FloatAB_,
           typename FloatAcc,
-          typename FloatC,
+          typename FloatC_,
           typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -95,8 +117,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
     static constexpr auto KPerBlock = K0PerBlock * K1;

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    using FloatCShuffle   = FloatAcc;

     using Block2CTileMap = Block2CTileMap_;
+    using FloatAB        = FloatAB_;
+    using FloatC         = FloatC_;

     struct Argument : public ck::tensor_operation::device::BaseArgument
     {
@@ -154,31 +179,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                           math::integer_divide_ceil(karg.M, MPerBlock),
                                           karg.k_batch);
     }

-#if 0
-    // prefer this to be called on host
-    __host__ __device__ static auto CalculateMPadded(index_t M)
-    {
-        return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
-    }
-
-    __host__ __device__ static auto CalculateNPadded(index_t N)
-    {
-        return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
-    }
-
-    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
-    {
-        // k_batch * k0 * k0_per_block * k1
-        auto K_t = K_Batch * K0PerBlock * K1;
-        return (K + K_t - 1) / K_t * K0PerBlock;
-    }
-
-    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
-    {
-        auto K0 = CalculateK0(K, K_Batch);
-        return K_Batch * K0 * K1;
-    }
-#endif

     __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
@@ -296,7 +296,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                              sizeof(FloatAB),
-                         c_block_size * sizeof(FloatC));
+                         c_block_size * sizeof(FloatCShuffle));
     }

     __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
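This and the following `FloatC` → `FloatCShuffle` substitutions all hang off the new alias `using FloatCShuffle = FloatAcc;` from the earlier hunk: the LDS staging block for the C shuffle is now sized and typed in accumulator precision, so partial results pass through LDS without an intermediate round-trip to the output type. A small illustration with stand-in types (assuming an fp16-output / fp32-accumulator instantiation, which the commit itself does not pin down):

```cpp
#include <cstdint>

using FloatAcc      = float;    // accumulator type
using FloatC        = uint16_t; // stand-in for an fp16 output type
using FloatCShuffle = FloatAcc; // alias introduced by this commit

// The c-shuffle LDS block grows from 2 to 4 bytes per element, in exchange
// for keeping shuffled partials in accumulator precision until the final
// LDS-to-global copy converts FloatCShuffle down to FloatC.
static_assert(sizeof(FloatCShuffle) == 2 * sizeof(FloatC), "twice the LDS per element");
```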
@@ -384,30 +384,29 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

-    // return block_id to C matrix tile idx (m0, n0, k_split) mapping
-    // __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
-    // {
-    //     return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
-    // }

     using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
-    // using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;

-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
-    {
-        const FloatAB* p_a_grid = karg.p_a_grid;
-        const FloatAB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid        = karg.p_c_grid;
-
-        uint32_t m = karg.M;
-        uint32_t n = karg.N;
-        uint32_t k = karg.K;
+    __device__ static void Run(const FloatAB* p_a_grid,
+                               const FloatAB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               Block2CTileMap block_mapping,
+                               void* __restrict__ p_shared_block)
+    {
+        uint32_t m = M;
+        uint32_t n = N;
+        uint32_t k = K;

         uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
         uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
         uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;

-        uint32_t stride_a = karg.StrideA;
-        uint32_t stride_b = karg.StrideB;
-        uint32_t stride_c = karg.StrideC;
+        uint32_t stride_a = StrideA;
+        uint32_t stride_b = StrideB;
+        uint32_t stride_c = StrideC;

         const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
         const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
@@ -467,7 +466,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

-        auto& block_mapping = karg.block_mapping;
         uint32_t block_idx = block_mapping.get_block_idx();
         bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
         bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx;
@@ -608,7 +606,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

         auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatC*>(p_shared_block),
+            static_cast<FloatCShuffle*>(p_shared_block),
             c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
@@ -662,7 +660,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // VGPR to LDS
         auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
             FloatAcc,
-            FloatC,
+            FloatCShuffle,
             decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
             decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
             ck::tensor_operation::element_wise::PassThrough,
@@ -701,7 +699,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                      CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
             CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            FloatC,               // typename SrcData,
+            FloatCShuffle,        // typename SrcData,
             FloatC,               // typename DstData,
             decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -807,10 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             // make sure next loop LDS is ready for use
             block_sync_lds();
         }
-
-        // if(threadIdx.x == 0)
-        //     printf("xxx bid:%d, xx_total_iter_length:%d \n", static_cast<int>(blockIdx.x),
-        //     xx_total_iter_length);
     }

     template <typename Layout>
...
@@ -162,7 +162,7 @@ struct MDiv
     // 1 dword -> 3 dword storage
     uint32_t divisor;
     uint32_t multiplier;
-    uint32_t shift;
+    uint32_t shift; // TODO: 8 bit is enough

     // prefer construct on host
     __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
...
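For context on that TODO: `MDiv` implements division by a runtime-invariant `uint32_t` with a precomputed magic multiplier, and the shift amount is always far below 256, so a byte of storage would suffice. A self-contained sketch of the standard round-up scheme (illustrative; the multiplier is kept 64-bit here for simplicity, whereas the real `MDiv` packs divisor, multiplier, and shift into the three dwords noted in the comment):

```cpp
#include <cstdint>

// q = floor(n / d) via one wide multiply and a shift, d fixed at construction.
struct MDivSketch
{
    uint32_t divisor;
    uint64_t multiplier;
    uint32_t shift; // ceil(log2(divisor)) <= 32, hence "8 bit is enough"

    explicit MDivSketch(uint32_t d) : divisor(d), shift(0)
    {
        while((1ull << shift) < d)
            ++shift;
        // ceil(2^(32+shift) / d); exact for every 32-bit n because the
        // rounding error n * (multiplier*d - 2^(32+shift)) stays below 2^(32+shift).
        multiplier = static_cast<uint64_t>(((__uint128_t(1) << (32 + shift)) + d - 1) / d);
    }

    void divmod(uint32_t n, uint32_t& quot, uint32_t& rem) const
    {
        quot = static_cast<uint32_t>((__uint128_t(n) * multiplier) >> (32 + shift));
        rem  = n - quot * divisor;
    }
};
```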