Commit b2a49620 authored by carlushuang

shrink karg for streamk

parent fcb2911e
@@ -143,8 +143,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         // TODO: remove clear buffer for streamk kernels
         hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
-        ave_time =
-            launch_and_time_kernel(stream_config, kernel, grid_dims, dim3(BlockSize), 0, karg);
+        ave_time = launch_and_time_kernel(stream_config,
+                                          kernel,
+                                          grid_dims,
+                                          dim3(BlockSize),
+                                          0,
+                                          karg.p_a_grid,
+                                          karg.p_b_grid,
+                                          karg.p_c_grid,
+                                          karg.M,
+                                          karg.N,
+                                          karg.K,
+                                          karg.StrideA,
+                                          karg.StrideB,
+                                          karg.StrideC,
+                                          karg.block_mapping);
         return ave_time;
     }
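The launch site now unpacks the `Argument` struct into plain scalars. For orientation, here is a minimal sketch of what a variadic `launch_and_time_kernel`-style wrapper can look like in HIP. This is an illustrative reimplementation, not CK's actual helper (CK's version also takes a `stream_config`); the point is that whatever is forwarded here lands verbatim in the kernarg segment.

```cpp
#include <hip/hip_runtime.h>
#include <cstddef>

// Illustrative only: times one launch with HIP events and forwards the
// unpacked scalar arguments straight into the kernel arguments.
template <typename Kernel, typename... Args>
float launch_and_time(Kernel kernel, dim3 grid, dim3 block,
                      std::size_t lds_bytes, hipStream_t stream, Args... args)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);
    hipEventRecord(start, stream);
    kernel<<<grid, block, lds_bytes, stream>>>(args...);
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);
    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}
```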
@@ -8,6 +8,7 @@
 #include "ck/tensor_description/tensor_adaptor.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include <limits>
+#include <stdlib.h>

 namespace ck {
@@ -635,14 +636,18 @@ struct BlockToCTileMap_3DGrid_KSplit
         return true;
     }
 };

-#include <stdlib.h>
-template <uint32_t MPerBlock_, uint32_t NPerBlock_, uint32_t KPerBlock_>
+template <uint32_t MPerBlock_,
+          uint32_t NPerBlock_,
+          uint32_t KPerBlock_,
+          uint32_t TileSwizzleSubM_ = 8>
 struct BlockToCTileMap_GemmStreamK
 {
     static constexpr uint32_t min_k_iters_per_sk_block = 2;
     static constexpr uint32_t MPerBlock = MPerBlock_;
     static constexpr uint32_t NPerBlock = NPerBlock_;
     static constexpr uint32_t KPerBlock = KPerBlock_;
+    static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;

     //--------------------------------------
     // pass to device
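The swizzle factor moves from a runtime `MDiv` member into the template parameter `TileSwizzleSubM_`. Besides shrinking the struct that rides in the kernel arguments, a compile-time divisor lets the compiler lower `/` and `%` to shifts and masks whenever the factor is a power of two (8 by default). A sketch of the effect, not CK code:

```cpp
#include <cstdint>

// With the factor baked in as a template parameter, these divisions
// compile to `x >> 3` and `x & 7` for the default factor of 8; no
// multiplier/shift state needs to travel with the kernel arguments.
template <uint32_t TileSwizzleSubM = 8>
struct SwizzleFactor
{
    static uint32_t div(uint32_t x) { return x / TileSwizzleSubM; }
    static uint32_t mod(uint32_t x) { return x % TileSwizzleSubM; }
};
```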
@@ -657,8 +662,8 @@ struct BlockToCTileMap_GemmStreamK
     uint32_t k_iters_per_big_block;
     MDiv k_iters_per_tile;
     MDiv n_tiles;
-    MDiv tile_swizzle_sub_m;
-    MDiv tile_swizzle_sub_m_rem;
+    // MDiv tile_swizzle_sub_m_rem;

     //--------------------------------------
     static int env_get_int(const char* var_name, int default_int)
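The body of `env_get_int` sits outside the shown context. Given the `<stdlib.h>` include added above, a plausible definition is the following (an assumption, not the verbatim source):

```cpp
#include <stdlib.h>

// Reads an integer override from the environment, falling back to the
// given default; getenv/atoi are what <stdlib.h> is included for.
static int env_get_int(const char* var_name, int default_int)
{
    const char* v = getenv(var_name);
    return v ? atoi(v) : default_int;
}
```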
@@ -676,8 +681,7 @@ struct BlockToCTileMap_GemmStreamK
               uint32_t k,
               uint32_t num_cu,
               uint32_t occupancy,
-              uint32_t sk_blocks = 0xffffffff,
-              uint32_t tile_swizzle_sub_m_factor = 8)
+              uint32_t sk_blocks = 0xffffffff)
     {
         uint32_t num_tiles =
             math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
@@ -723,15 +727,8 @@ struct BlockToCTileMap_GemmStreamK
             sk_tiles = partial_dispatche_tiles + num_cu;
         }

-        // dp_num_blocks = dp_tiles;
-        // dp_start_block_idx = num_cu * sk_occupancy;
         dp_iters_per_block = k_iters_per_tile.get();
-        sk_total_iters     = k_iters_per_tile.get() * sk_tiles;
-
-        // printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
-        // partial_dispatche_tiles:%d\n",
-        //        num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
+        sk_total_iters = k_iters_per_tile.get() * sk_tiles;

         {
             uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
@@ -812,11 +809,8 @@ struct BlockToCTileMap_GemmStreamK
         }

         n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
-        tile_swizzle_sub_m_factor =
-            env_get_int("tile_swizzle_sub_m_factor", tile_swizzle_sub_m_factor);
-        tile_swizzle_sub_m = MDiv(tile_swizzle_sub_m_factor);
-        tile_swizzle_sub_m_rem =
-            MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m_factor);
+        // tile_swizzle_sub_m_rem =
+        //     MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);

         printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
@@ -896,22 +890,25 @@ struct BlockToCTileMap_GemmStreamK
         // swizzle tile
         uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
         // uint32_t n_tiles = math::integer_divide_ceil(n, NPerBlock);
-        uint32_t quo_sub_m, rem_sub_m;
-        tile_swizzle_sub_m.divmod(m_tile_idx, quo_sub_m, rem_sub_m);
+        uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;

-        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem.get()))
+        const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
                                      ? tile_swizzle_sub_m
                                      : tile_swizzle_sub_m_rem;

         uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
-        tile_swizzle_sub_m.divmod(m_tile_idx, m_tile_idx_sub0, m_tile_idx_sub1);
+        m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
+        m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;

         uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles.get();

         uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
-        sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);
-        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m.get(),
+        n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
+        m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
+        // sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);
+        return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
                           n_tile_idx_with_adapt);
     }
 };
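Pulled out of the diff, the new swizzle math reads like the standalone function below (types simplified; the cache-reuse rationale is the usual motivation for sub-M swizzling, not a claim made by this commit):

```cpp
#include <cstdint>
#include <tuple>

// Same arithmetic as the new code above, with MDiv replaced by plain
// integer ops. Consecutive blocks walk `sub_m` rows of C tiles before
// stepping in N, the usual trick to improve reuse of loaded A tiles.
std::tuple<uint32_t, uint32_t> swizzle_tile(uint32_t m_tile_idx,
                                            uint32_t n_tile_idx,
                                            uint32_t m_tiles,
                                            uint32_t n_tiles,
                                            uint32_t sub_m) // tile_swizzle_sub_m
{
    uint32_t rem   = m_tiles % sub_m;                              // ragged tail rows
    uint32_t adapt = (m_tile_idx < (m_tiles - rem)) ? sub_m : rem; // shrink for the tail
    uint32_t sub0  = m_tile_idx / sub_m;                           // swizzle band index
    uint32_t sub1  = m_tile_idx % sub_m;                           // row within the band
    uint32_t local = n_tile_idx + sub1 * n_tiles;                  // flatten inside band
    return std::make_tuple(local % adapt + sub0 * sub_m,           // (m_tile, n_tile)
                           local / adapt);
}
```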
@@ -24,14 +24,36 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
+        // kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
+                                   const typename GridwiseGemm::FloatAB* p_b_grid,
+                                   typename GridwiseGemm::FloatC* p_c_grid,
+                                   index_t M,
+                                   index_t N,
+                                   index_t K,
+                                   index_t StrideA,
+                                   index_t StrideB,
+                                   index_t StrideC,
+                                   typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

-    GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
+    // GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      M,
+                      N,
+                      K,
+                      StrideA,
+                      StrideB,
+                      StrideC,
+                      block_mapping,
+                      static_cast<void*>(p_shared));
 #else
     ignore = karg;
     ignore = b2c_map;
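Why unpack at all: the old `karg` was a `GridwiseGemm::Argument`, which (per the declaration further down) inherits from `ck::tensor_operation::device::BaseArgument`. A polymorphic struct passed by value drags a host-side vtable pointer plus padding into the kernel argument segment. (Reviewer note: the `#else` fallback above still assigns `ignore = karg` and `ignore = b2c_map`, neither of which exists in the new signature, so that path still needs updating.) A toy illustration of the overhead, with a hypothetical layout rather than the real `Argument`:

```cpp
#include <cstdint>

struct BaseArg { virtual ~BaseArg() = default; }; // vtable pointer lives here

struct FatArg : BaseArg // stand-in for the old monolithic Argument
{
    const void* p_a; const void* p_b; void* p_c;
    int32_t M, N, K, StrideA, StrideB, StrideC;
    uint32_t mapping_state[16]; // stand-in for the block-mapping payload
};

// The payload the kernel actually reads is smaller than the struct:
static_assert(sizeof(FatArg) > 3 * sizeof(void*) + 6 * sizeof(int32_t) +
                                   16 * sizeof(uint32_t),
              "vtable pointer and padding inflate the kernarg footprint");
```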
@@ -40,9 +62,9 @@ __global__ void
 template <index_t BlockSize,
           typename Block2CTileMap_,
-          typename FloatAB,
+          typename FloatAB_,
           typename FloatAcc,
-          typename FloatC,
+          typename FloatC_,
           typename ALayout,
           typename BLayout,
           typename CLayout,
@@ -95,8 +117,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
     static constexpr auto KPerBlock = K0PerBlock * K1;

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
     using FloatCShuffle   = FloatAcc;
+    using Block2CTileMap  = Block2CTileMap_;
+    using FloatAB         = FloatAB_;
+    using FloatC          = FloatC_;

     struct Argument : public ck::tensor_operation::device::BaseArgument
     {
@@ -154,31 +179,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                           math::integer_divide_ceil(karg.M, MPerBlock),
                           karg.k_batch);
     }

-#if 0
-    // prefer this to be called on host
-    __host__ __device__ static auto CalculateMPadded(index_t M)
-    {
-        return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
-    }
-
-    __host__ __device__ static auto CalculateNPadded(index_t N)
-    {
-        return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
-    }
-
-    __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
-    {
-        // k_batch * k0 * k0_per_block * k1
-        auto K_t = K_Batch * K0PerBlock * K1;
-        return (K + K_t - 1) / K_t * K0PerBlock;
-    }
-
-    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
-    {
-        auto K0 = CalculateK0(K, K_Batch);
-        return K_Batch * K0 * K1;
-    }
-#endif

     __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; }
@@ -296,7 +296,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                              sizeof(FloatAB),
-                         c_block_size * sizeof(FloatC));
+                         c_block_size * sizeof(FloatCShuffle));
     }

     __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
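The `sizeof(FloatC)` to `sizeof(FloatCShuffle)` change matters because the LDS buffer is reused: during the main loop it holds the A and B tiles, and during the epilogue it holds the accumulator-typed C shuffle tile, so the allocation is the max of the two. A worked example with assumed tile shapes, not this kernel's actual defaults:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Hypothetical fp16 A/B tiles of 128x32 each, fp32 C-shuffle tile of 32x64.
constexpr std::size_t a_elems = 128 * 32;
constexpr std::size_t b_elems = 128 * 32;
constexpr std::size_t c_elems = 32 * 64;

constexpr std::size_t lds_bytes =
    std::max((a_elems + b_elems) * sizeof(uint16_t), // A+B stage: 16384 B
             c_elems * sizeof(float));               // C shuffle:  8192 B

static_assert(lds_bytes == 16384, "the A+B stage dominates in this example");
```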
@@ -384,30 +384,29 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
     }

     // return block_id to C matrix tile idx (m0, n0, k_split) mapping
-    // __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
-    // {
-    //     return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
-    // }

     using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
-    // using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;

-    __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block)
+    __device__ static void Run(const FloatAB* p_a_grid,
+                               const FloatAB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               Block2CTileMap block_mapping,
+                               void* __restrict__ p_shared_block)
     {
-        const FloatAB* p_a_grid = karg.p_a_grid;
-        const FloatAB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid        = karg.p_c_grid;
-        uint32_t m              = karg.M;
-        uint32_t n              = karg.N;
-        uint32_t k              = karg.K;
+        uint32_t m = M;
+        uint32_t n = N;
+        uint32_t k = K;

         uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
         uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
         uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;

-        uint32_t stride_a = karg.StrideA;
-        uint32_t stride_b = karg.StrideB;
-        uint32_t stride_c = karg.StrideC;
+        uint32_t stride_a = StrideA;
+        uint32_t stride_b = StrideB;
+        uint32_t stride_c = StrideC;

         const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
         const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
@@ -467,10 +466,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

-        auto& block_mapping = karg.block_mapping;
-        uint32_t block_idx  = block_mapping.get_block_idx();
-        bool is_sk_block    = block_idx < block_mapping.sk_num_blocks;
-        bool is_dp_block    = block_idx >= block_mapping.dp_start_block_idx;
+        uint32_t block_idx = block_mapping.get_block_idx();
+        bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
+        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx;

         uint32_t iter_start, iter_end;
         block_mapping.get_block_itr(block_idx, iter_start, iter_end);
         uint32_t total_iter_length = iter_end - iter_start;
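`get_block_itr` is what makes this kernel stream-k: each block owns a contiguous range of K-iterations that may straddle C-tile boundaries, and partial tiles are later reduced across blocks. Its body is outside this diff; a simplified sketch consistent with the `sk_num_big_blocks` and `k_iters_per_big_block` fields above (an assumption, not the verbatim mapping) is:

```cpp
#include <cstdint>

// Big blocks take k_iters_per_big_block iterations, the remaining
// stream-k blocks take one fewer; dp blocks are handled separately.
inline void get_block_itr_sketch(uint32_t block_idx,
                                 uint32_t sk_num_big_blocks,
                                 uint32_t k_iters_per_big_block,
                                 uint32_t& iter_start,
                                 uint32_t& iter_end)
{
    if(block_idx < sk_num_big_blocks)
    {
        iter_start = block_idx * k_iters_per_big_block;
        iter_end   = iter_start + k_iters_per_big_block;
    }
    else
    {
        uint32_t base = sk_num_big_blocks * k_iters_per_big_block;
        iter_start = base + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
        iter_end   = iter_start + (k_iters_per_big_block - 1);
    }
}
```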
@@ -608,7 +606,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

         auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatC*>(p_shared_block),
+            static_cast<FloatCShuffle*>(p_shared_block),
             c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
@@ -662,7 +660,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // VGPR to LDS
         auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
             FloatAcc,
-            FloatC,
+            FloatCShuffle,
             decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
             decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
             ck::tensor_operation::element_wise::PassThrough,
@@ -701,7 +699,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                          CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
             CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            FloatC,               // typename SrcData,
+            FloatCShuffle,        // typename SrcData,
             FloatC,               // typename DstData,
             decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
@@ -807,10 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk

             // make sure next loop LDS is ready for use
             block_sync_lds();
         }
-
-        // if(threadIdx.x == 0)
-        //     printf("xxx bid:%d, xx_total_iter_length:%d \n", static_cast<int>(blockIdx.x),
-        //     xx_total_iter_length);
     }

     template <typename Layout>
@@ -162,7 +162,7 @@ struct MDiv
     // 1 dword -> 3 dword storage
     uint32_t divisor;
     uint32_t multiplier;
-    uint32_t shift;
+    uint32_t shift; // TODO: 8 bit is enough

     // prefer construct on host
     __host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
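`MDiv` is a Granlund-Montgomery style fast-division helper: it precomputes a multiplier and shift on the host so device code can divide by a runtime value with a multiply-high plus a shift (hence the "1 dword -> 3 dword storage" comment, and the TODO that 8 bits would suffice for the shift). A simplified sketch of the same idea using a 64-bit reciprocal instead of MDiv's exact 32-bit multiplier/shift encoding:

```cpp
#include <cstdint>

// Sketch only (assumes GCC/Clang's unsigned __int128): exact division
// of any 32-bit numerator by a precomputed 32-bit divisor d > 1 via
// magic = ceil(2^64 / d) and a 128-bit multiply-high.
struct FastDiv
{
    uint32_t divisor;
    uint64_t magic; // ceil(2^64 / divisor)

    explicit FastDiv(uint32_t d) : divisor(d), magic(d > 1 ? ~uint64_t{0} / d + 1 : 0) {}

    uint32_t div(uint32_t n) const
    {
        if(divisor == 1)
            return n;
        return static_cast<uint32_t>((static_cast<unsigned __int128>(magic) * n) >> 64);
    }

    // divmod as one call, like MDiv::divmod in the tile-mapping code above
    void divmod(uint32_t n, uint32_t& q, uint32_t& r) const
    {
        q = div(n);
        r = n - q * divisor;
    }
};
```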