Commit 5b1e2442 authored by Harisankar Sadasivan

2-tile sk + DP with atomics for FP16

parent 2ae16e90
@@ -21,3 +21,25 @@ Warm up
 Start running 5 times...
 Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
 ```
+# Instructions for ```example_gemm_xdl_streamk```
+## Run ```example_gemm_xdl_streamk```
+```bash
+# arg1: verification (0=no, 1=yes)
+# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
+# arg3: time kernel (0=no, 1=yes)
+# arg4 to 9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC
+# arg10: NumSKBlocks (optional; defaults to DP GEMM)
+bin/example_gemm_xdl_streamk 1 2 1 3840 4096 4096 4096 4096 4096 312
+```
+Result (MI250 @ 1700 MHz, 181 TFlops peak FP16 on 1 die)
+```
+a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
+b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
+c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
+Recommended grid size :312
+Perf: 1.21689 ms, 105.884 TFlops, 79.2748 GB/s, GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8
+```
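The "Recommended grid size" line comes from launching one persistent work-group per compute unit per occupancy slot. Below is a minimal host-side sketch of that query; `gemm_kernel` and the block size of 256 are placeholders for the real templated kernel instantiation, and the 312 reported above appears to correspond to 104 CUs on one MI250 GCD with an occupancy of 3.

```cpp
// Sketch only: how a recommended stream-k grid size can be derived on the host.
// "gemm_kernel" and BlockSize are stand-ins for kernel_gemm_xdlops_streamk<GridwiseGemm>.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void gemm_kernel() {} // placeholder kernel symbol

int main()
{
    constexpr int BlockSize = 256;

    // max work-groups of this kernel that fit on one CU (no dynamic LDS assumed here)
    int occupancy = 0;
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, gemm_kernel, BlockSize, 0);

    hipDevice_t dev;
    hipDeviceProp_t prop;
    hipGetDevice(&dev);
    hipGetDeviceProperties(&prop, dev);

    // one persistent work-group per (CU, occupancy slot); e.g. 104 * 3 = 312 on one MI250 GCD
    int num_cu    = prop.multiProcessorCount;
    int grid_size = num_cu * occupancy;
    std::printf("recommended stream-k grid size: %d\n", grid_size);
    return 0;
}
```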
@@ -137,6 +137,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                "setting");
        }

+        // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
+        // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
        dim3 grid_dims = karg.block_mapping.get_grid_dims();

        float ave_time = 0;
@@ -268,22 +270,19 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                          AElementwiseOperation,
                          BElementwiseOperation,
                          CElementwiseOperation,
-                          uint32_t NumSKBlocks = 0xffffffff)
+                          uint32_t NumSKBlocks = 0)
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
-        hipError_t rtn;
-        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
-            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
-        hip_check_error(rtn);
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
-        rtn = hipGetDevice(&dev);
-        hip_check_error(rtn);
-        rtn = hipGetDeviceProperties(&dev_prop, dev);
-        hip_check_error(rtn);
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;
+        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
+               num_cu * occupancy);

        return Argument{p_a,
                        p_b,
@@ -318,17 +317,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
-        hipError_t rtn;
-        rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
-            &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
-        hip_check_error(rtn);
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
-        rtn = hipGetDevice(&dev);
-        hip_check_error(rtn);
-        rtn = hipGetDeviceProperties(&dev_prop, dev);
-        hip_check_error(rtn);
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;

        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
...
@@ -1010,113 +1010,27 @@ struct BlockToCTileMap_GemmStreamK
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction
-    // MDiv tile_swizzle_sub_m_rem;
-    //--------------------------------------

    // prefer construct on host
    BlockToCTileMap_GemmStreamK(uint32_t m,
                                uint32_t n,
                                uint32_t k,
                                uint32_t num_cu,
                                uint32_t occupancy,
-                                uint32_t sk_blocks = 0xffffffff)
+                                uint32_t sk_blocks = 0)
    {
+        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

-        // one cu can hold one wg at one time, from the whole chip's point of view
-        // if number of wg is same as num_cu, we call it 1 dispatch
-        // if number of wg is 2x num_cu, we call it 2 dispatches.
-        // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
-        // dispatch)
-        //
-        uint32_t full_dispatches         = num_tiles / num_cu;
-        uint32_t full_dispatch_tiles     = full_dispatches * num_cu;
-        uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
-        uint32_t sk_occupancy = occupancy;
-        uint32_t dp_tiles     = full_dispatch_tiles;
-        uint32_t sk_tiles     = partial_dispatche_tiles;
-        if(full_dispatches < occupancy)
-        {
-            // in this case, we allocate all blocks as sk blocks
-            // sk_occupancy = occupancy - full_dispatches;
-            sk_occupancy = 1; // TODO: single occ seems better
-            dp_tiles     = full_dispatch_tiles;
-            sk_tiles     = partial_dispatche_tiles;
-        }
-        else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
-        {
-            // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
-            // occupancy = 3, full_dispatches = 5, 8, 11 ...
-            // occupancy = 4, full_dispatches = 7, 11 ...
-            sk_occupancy = 1; // left 1 slot for sk occupancy
-            dp_tiles     = full_dispatch_tiles;
-            sk_tiles     = partial_dispatche_tiles;
-        }
-        else
-        {
-            // others, we reduce 1 dispatch from dp, together with partial dispatch,
-            // to construct sk dispatch
-            sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
-            dp_tiles     = full_dispatch_tiles - num_cu;
-            sk_tiles     = partial_dispatche_tiles + num_cu;
-        }
-        // uint32_t dp_iters_per_block = k_iters_per_tile.get();
-        uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
-        uint32_t dp_num_blocks  = 0;
-        {
-            uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
-            uint32_t max_sk_tiles =
-                (sk_tiles >= num_cu) ? num_cu * sk_occupancy
-                                     : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
-            // if use dp for sk-block, how many iters do we need
-            uint32_t dp_for_sk_iters = k_iters_per_tile.get();
-            uint32_t best_sk_score =
-                std::numeric_limits<int>::max(); // we need to find the smallest sk iters
-            for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
-                tentative_sk_blocks++)
-            {
-                uint32_t tentative_sk_iters_per_block =
-                    (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
-                uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
-                uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
-                // TODO: carefully adjust this parameter
-                // the more sk_blocks_per_tile, the worse the overhead
-                uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
-                if(tentative_sk_blocks % sk_tiles != 0)
-                {
-                    // penalty for uneven divide
-                    cross_sk_blocks_overhead +=
-                        sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
-                }
-                uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
-                if(tentative_sk_score < best_sk_score)
-                {
-                    best_sk_score = tentative_sk_score;
-                    sk_num_blocks = tentative_sk_blocks;
-                }
-            }
-            if(best_sk_score >= dp_for_sk_iters)
+        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
+
+        // default to regular DP GEMM if sk blocks == 0
+        sk_num_blocks = sk_blocks;
+        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
        {
            sk_num_blocks = 0;
-            }
-            // give a chance to control num of sk blocks
-            sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
-        if(sk_num_blocks == 0)
-        {
+            dp_tiles      = num_tiles;

            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;
@@ -1124,8 +1038,20 @@ struct BlockToCTileMap_GemmStreamK
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear this tiles
        }
+        // 2-tile sk + DP GEMM
        else
        {
+            // grid size
+            uint32_t grid_size = occupancy * num_cu;
+            // check if there's enough work for DP+ stream-k
+            bool bigEnough = num_tiles > grid_size;
+            // max of 2 sk tiles per block
+            uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles;
+            // remaining tiles are DP tiles
+            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
+
+            sk_total_iters = k_iters_per_tile.get() * sk_tiles;
+
            // k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
            // we need to decide how many iters for each sk block
            // let m = k_iters_per_sk_block
@@ -1144,8 +1070,9 @@ struct BlockToCTileMap_GemmStreamK
            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
        }
-        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
+        // using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
@@ -1157,13 +1084,14 @@ struct BlockToCTileMap_GemmStreamK
        }

#if 0
-        printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
+        printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
-               "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
+               "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
+               // get_grid_dims(num_cu, occupancy).x,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
@@ -1171,7 +1099,7 @@ struct BlockToCTileMap_GemmStreamK
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
-               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
@@ -1195,7 +1123,8 @@ struct BlockToCTileMap_GemmStreamK
        return k_iters_per_tile.div(sk_total_iters);
    }

-    __host__ __device__ dim3 get_grid_dims() const
+    // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
+    __host__ __device__ constexpr dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
@@ -1203,6 +1132,16 @@ struct BlockToCTileMap_GemmStreamK
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
+        // return dim3(num_cu * occupancy, 1, 1); // HS
+    }
+
+    __host__ __device__ uint32_t total_blocks_allocated() const
+    {
+        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
+        {
+            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx + get_sk_tiles());
+        }
+        else
+            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx);
    }

    __device__ uint32_t get_block_idx() const
...
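In the new mapping, stream-k work is capped at roughly two output tiles per persistent block: the final wave of tiles (one per launched block) plus the leftover partial wave goes to stream-k blocks, and every earlier tile stays a plain data-parallel (DP) block. The following standalone sketch walks that split for the README's 3840x4096x4096 example; the 128x128 macro-tile, 32-wide K step, 104 CUs, and occupancy of 3 are assumptions inferred from the instance name and the reported grid size, not values taken from this diff.

```cpp
// Standalone sketch of the 2-tile stream-k / DP split; not the library code itself.
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t M = 3840, N = 4096, K = 4096;
    const uint32_t MPerBlock = 128, NPerBlock = 128, KPerIter = 32; // assumed tile sizes
    const uint32_t num_cu = 104, occupancy = 3;                     // assumed, one MI250 GCD

    const uint32_t num_tiles =
        ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
    const uint32_t k_iters_per_tile = (K + KPerIter - 1) / KPerIter;

    const uint32_t grid_size  = num_cu * occupancy;  // 312 persistent work-groups
    const bool     big_enough = num_tiles > grid_size; // enough work to keep a DP portion

    // stream-k takes the last full wave of tiles plus the remainder (<= 2 tiles per block);
    // all remaining tiles are handled by plain data-parallel blocks
    const uint32_t sk_tiles       = big_enough ? grid_size + num_tiles % grid_size : num_tiles;
    const uint32_t dp_tiles       = big_enough ? num_tiles - sk_tiles : 0;
    const uint32_t sk_total_iters = k_iters_per_tile * sk_tiles;

    std::printf("num_tiles=%u sk_tiles=%u dp_tiles=%u sk_total_iters=%u grid=%u\n",
                num_tiles, sk_tiles, dp_tiles, sk_total_iters, grid_size);
    // expected: num_tiles=960 sk_tiles=336 dp_tiles=624 sk_total_iters=43008 grid=312
    return 0;
}
```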
@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
+        index_t num_cu, occupancy; // stream-k arguments
        Block2CTileMap block_mapping;

        Argument(const FloatAB* p_a_grid_,
@@ -156,8 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
-                 uint32_t num_cu,
-                 uint32_t occupancy,
+                 uint32_t num_cu_,
+                 uint32_t occupancy_,
                 uint32_t num_sk_blocks_)
            : p_a_grid(p_a_grid_),
              p_b_grid(p_b_grid_),
@@ -168,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
-              block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
+              num_cu(num_cu_),
+              occupancy(occupancy_),
+              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
        {
        }
@@ -461,7 +464,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        uint32_t stride_a = StrideA;
        uint32_t stride_b = StrideB;
        uint32_t stride_c = StrideC;
+        uint32_t block_idx = block_mapping.get_block_idx();

        const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
        const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
        const auto c_grid_desc_m_n     = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
@@ -520,39 +523,53 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

+        uint32_t* p_semaphore =
+            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
+                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
-        uint32_t block_idx = block_mapping.get_block_idx();
-        bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
-        bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx &&
-                             block_idx < block_mapping.reduction_start_block_idx;
-        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
-        bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
-                                  block_idx < block_mapping.dp_start_block_idx;
+        // offset for last acc buffer of this block
+        uint32_t block_acc_offset =
+            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
+            NPerBlock;

        uint32_t iter_start, iter_end;
-        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
-        uint32_t total_iter_length = iter_end - iter_start;
+        bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
+        uint32_t total_iter_length;
+
+#pragma unroll
+        // stream-k: for new work for all the persistent blocks.
+        for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
+        {
+            is_sk_block = block_idx < block_mapping.sk_num_blocks;
+            is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
+                          block_idx < block_mapping.reduction_start_block_idx;
+            is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
+                               block_idx < block_mapping.dp_start_block_idx;

            if(is_padding_block)
-                return;
+            {
+                continue;
+            }

-        uint32_t* p_semaphore =
-            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
-                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
+            block_mapping.get_block_itr(block_idx, iter_start, iter_end);
+            total_iter_length = iter_end - iter_start;

            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
            {
+                is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
                if(is_reduction_block)
                {
                    // descriptors
                    constexpr auto cluster_length_reduce = GetClusterLengthReduction();
                    constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
-                    const auto reduce_thread_cluster_idx =
-                        reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+                    const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
+                        make_multi_index(get_thread_local_1d_id()));
                    const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
                    const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];

-                    constexpr auto MReduceIters =
-                        math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+                    constexpr auto MReduceIters = math::integer_divide_ceil(
+                        Number<MPerBlock>{}, cluster_length_reduce.At(I0));
                    constexpr auto NReduceIters = math::integer_divide_ceil(
                        Number<NPerBlock>{},
                        cluster_length_reduce.At(I1) *
@@ -560,15 +577,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
                        make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
-                    constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
-                        make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+                    constexpr auto acc_thread_buf_store_desc =
+                        make_naive_tensor_descriptor_packed(make_tuple(
+                            I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));

                    constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();

                    constexpr auto partial_acc_load_step_n = make_multi_index(
-                        0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
-                    constexpr auto partial_acc_load_step_n_reverse =
-                        make_multi_index(0,
+                        0,
+                        cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
+                    constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
+                        0,
                        -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                            CBlockTransferScalarPerVector_NWaveNPerXDL);
                    constexpr auto partial_acc_load_step_m =
@@ -579,8 +598,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        0,
                        0,
                        cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
-                    constexpr auto partial_acc_store_step_n_reverse =
-                        make_multi_index(0,
+                    constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
+                        0,
                        0,
                        0,
                        -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
@@ -600,7 +619,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        acc_buf;

                    // start to compute
-                    auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
+                    auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
                    auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);

                    workgroup_barrier wg_barrier(p_semaphore);
@@ -632,7 +651,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        decltype(acc_thread_buf_store_desc),                      // SrcDesc,
                        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),  // DstDesc,
                        CElementwiseOperation,                                    // ElementwiseOperation,
-                        Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
+                        Sequence<1,
+                                 1,
+                                 1,
+                                 CBlockTransferScalarPerVector_NWaveNPerXDL>,     // SliceLengths,
                        Sequence<0, 1, 2, 3>,                                     // DimAccessOrder,
                        3,                                                        // DstVectorDim,
                        CBlockTransferScalarPerVector_NWaveNPerXDL,               // DstScalarPerVector,
@@ -652,7 +674,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
#if 0
                    if(threadIdx.x == 0) {
-                        printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
+                        printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(block_idx),
                            reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                            __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                            __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
@@ -723,19 +745,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        {
                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                        partial_acc_load_step_m);
-                            acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
+                            acc_store.MoveDstSliceWindow(
+                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                partial_acc_store_step_m);
                        }
                    }
                    return;
                }
            }

-        // offset for last acc buffer of this block
-        uint32_t block_acc_offset =
-            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
-            NPerBlock;
            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -755,8 +772,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);

                // A matrix blockwise copy
-                auto a_blockwise_copy =
-                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+                    ThisThreadBlock,
                    AElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
@@ -776,8 +793,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    1,
                    1,
                    AThreadTransferSrcResetCoordinateAfterRun,
-                    true>(
-                    a_k0_m_k1_grid_desc,
+                    true>(a_k0_m_k1_grid_desc,
                    make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                    a_element_op,
                    a_block_desc_k0_m_k1,
@@ -785,8 +801,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    ck::tensor_operation::element_wise::PassThrough{});

                // B matrix blockwise copy
-                auto b_blockwise_copy =
-                    ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+                auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+                    ThisThreadBlock,
                    BElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
@@ -806,8 +822,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    1,
                    1,
                    BThreadTransferSrcResetCoordinateAfterRun,
-                    true>(
-                    b_k0_n_k1_grid_desc,
+                    true>(b_k0_n_k1_grid_desc,
                    make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                    b_element_op,
                    b_block_desc_k0_n_k1,
@@ -868,7 +883,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                        .GetElementSpaceSize());

-                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+                constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+                    transform_tensor_descriptor(
                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                    make_tuple(make_freeze_transform(I0), // freeze mblock
                               make_unmerge_transform(
@@ -994,10 +1010,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                    Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                    3,                                          // index_t VectorDim,
                    CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
-                    false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false,
-                           // othre wise has scratch
-                    false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false,
-                           // othre wise has scratch
+                    false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
+                           // false, othre wise has scratch
+                    false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
+                           // false, othre wise has scratch
                    {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                     make_multi_index(0, 0, 0, 0),
                     c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
@@ -1014,7 +1030,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                    constexpr auto mxdlperwave = mxdlperwave_iter;

-                    static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
+                    static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}(
+                        [&](auto nxdlperwave_iter) {
                        constexpr bool nxdlperwave_forward_sweep =
                            (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
@@ -1045,7 +1062,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                        // LDS to global
                        if(is_dp_block)
-                            c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
+                            c_block_copy_lds_to_global
+                                .template Run<decltype(c_block_buf),
                                              decltype(c_grid_buf),
                                              InMemoryDataOperationEnum::Set>(
                                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
@@ -1066,8 +1084,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

-                            c_block_copy_lds_to_partial_acc
-                                .template Run<decltype(c_block_buf),
+                            c_block_copy_lds_to_partial_acc.template Run<
+                                decltype(c_block_buf),
                                              decltype(c_partial_acc_buf),
                                              InMemoryDataOperationEnum::Set>(
                                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
@@ -1131,7 +1149,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                if(iter_end <= iter_start)
                    break;

-                if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
+                if constexpr(Block2CTileMap::ReductionStrategy ==
+                             StreamKReductionStrategy::Reduction)
                {
                    block_acc_offset -= MPerBlock * NPerBlock;
                }
@@ -1139,6 +1158,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                block_sync_lds();
            }
        }
+        }
    }

    template <typename Layout>
    struct LStr
...
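The loop introduced above turns each launched work-group into a persistent block: instead of exiting after a single unit of work, it strides through logical block indices by gridDim.x until all allocated blocks (stream-k, DP, and reduction) are covered. A minimal HIP sketch of that dispatch pattern follows; `total_blocks`, `do_work_for_block`, and `persistent_kernel` are hypothetical stand-ins for `block_mapping.total_blocks_allocated()` and the per-block GEMM body in the real kernel, and the real kernel additionally derives its starting index from the block-to-C-tile mapping rather than blockIdx.x directly.

```cpp
#include <hip/hip_runtime.h>

// Hypothetical per-block work item; stands in for the sk/dp/reduction body in the real kernel.
__device__ void do_work_for_block(uint32_t block_idx) { (void)block_idx; }

// Persistent-block dispatch: a fixed grid (e.g. num_cu * occupancy work-groups) walks all
// logical block indices, so the number of logical blocks may exceed the launched grid size.
__global__ void persistent_kernel(uint32_t total_blocks)
{
    for(uint32_t block_idx = blockIdx.x; block_idx < total_blocks; block_idx += gridDim.x)
    {
        // per-logical-block skip/dispatch logic goes here, as in the stream-k kernel
        do_work_for_block(block_idx);
    }
}

// launch: hipLaunchKernelGGL(persistent_kernel, dim3(grid_size), dim3(BlockSize), 0, 0, total_blocks);
```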