"...composable_kernel_rocm.git" did not exist on "238d58c2f5947246a3e62f72db2b175b2e948554"
Commit 5b1e2442 authored by Harisankar Sadasivan

2-tile sk+ DP with atomics for FP16

parent 2ae16e90
@@ -21,3 +21,25 @@ Warm up
Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
```
# Instructions for ```example_gemm_xdl_streamk```
## Run ```example_gemm_xdl_streamk```
```bash
# arg1: verification (0=no, 1=yes)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
# arg10: NumSKBlocks(optional, defaults to DP GEMM)
bin/example_gemm_xdl_streamk 1 2 1 3840 4096 4096 4096 4096 4096 312
```
Result (MI250 @ 1700 MHz, 181 TFlops peak FP16 on one die)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Recommended grid size :312
Perf: 1.21689 ms, 105.884 TFlops, 79.2748 GB/s, GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8
```
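
The recommended grid size comes from the occupancy query that `MakeArgument` performs in the device op below: the number of CUs times the kernel's occupancy per CU. A minimal stand-alone sketch of that query is shown here; `dummy_kernel` and `BlockSize = 256` are placeholders, while the real device op queries `kernel_gemm_xdlops_streamk<GridwiseGemm>` with its own block size:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

// Placeholder kernel; the device op runs this query against
// kernel_gemm_xdlops_streamk<GridwiseGemm> instead.
__global__ void dummy_kernel() {}

int main()
{
    constexpr int BlockSize = 256;

    int occupancy = 0;
    // max resident workgroups per CU for this kernel, BlockSize threads, 0 dynamic LDS
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, BlockSize, 0);

    int dev = 0;
    hipDeviceProp_t dev_prop;
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&dev_prop, dev);
    const int num_cu = dev_prop.multiProcessorCount;

    // one persistent workgroup per CU per occupancy slot
    printf("recommended stream-k grid size for tuning: %d\n", num_cu * occupancy);
    return 0;
}
```

One MI250 GCD exposes 104 CUs, so an occupancy of 3 workgroups per CU gives 312, which is consistent with the grid size reported above and with the `NumSKBlocks` value used in the example run.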
@@ -137,6 +137,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                         "setting");
            }
            // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
            // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
            dim3 grid_dims = karg.block_mapping.get_grid_dims();

            float ave_time = 0;
@@ -268,22 +270,19 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                            AElementwiseOperation,
                            BElementwiseOperation,
                            CElementwiseOperation,
                            uint32_t NumSKBlocks = 0)
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;
        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
               num_cu * occupancy);

        return Argument{p_a,
                        p_b,
@@ -318,17 +317,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
    {
        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

        hipDeviceProp_t dev_prop;
        hipDevice_t dev;
        hip_check_error(hipGetDevice(&dev));
        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
        num_cu = dev_prop.multiProcessorCount;

        return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
......
@@ -1010,142 +1010,69 @@ struct BlockToCTileMap_GemmStreamK
    MDiv eqav_tiles_big;    // for reduction
    MDiv eqav_tiles_little; // for reduction

    // prefer construct on host
    BlockToCTileMap_GemmStreamK(uint32_t m,
                                uint32_t n,
                                uint32_t k,
                                uint32_t num_cu,
                                uint32_t occupancy,
                                uint32_t sk_blocks = 0)
    {
        // total output tiles
        uint32_t num_tiles =
            math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
        k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));

        uint32_t dp_tiles, dp_num_blocks, sk_total_iters;

        // default to regular DP GEMM if sk blocks == 0
        sk_num_blocks = sk_blocks;
        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
        {
            sk_num_blocks         = 0;
            dp_tiles              = num_tiles;
            sk_num_big_blocks     = 0;
            k_iters_per_big_block = 0;

            dp_num_blocks      = num_tiles; // all tiles become dp blocks
            dp_start_block_idx = 0;
            sk_total_iters     = 0; // clear the sk iters
        }
        // 2-tile sk + DP GEMM
        else
        {
            // grid size
            uint32_t grid_size = occupancy * num_cu;
            // check if there's enough work for DP + stream-k
            bool bigEnough = num_tiles > grid_size;
            // max of 2 sk tiles per block
            uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles;
            // remaining tiles are DP tiles
            dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;

            sk_total_iters = k_iters_per_tile.get() * sk_tiles;

            // k_iters_per_sk_block is the floor of the average number of iters each sk block
            // loops over. We need to decide how many iters each sk block covers:
            // let m = k_iters_per_sk_block
            // some of the sk blocks (little) will cover m iters, some (big) will cover m+1
            // we have
            //  1) l + b = sk_blocks
            //  2) l * m + b * (m + 1) = sk_total_iters
            //     => (l + b) * m + b = sk_total_iters
            //     => sk_blocks * m + b = sk_total_iters
            //     => b = sk_total_iters - m * sk_blocks
            // NOTE: big could be zero
            uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
            sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
            k_iters_per_big_block = k_iters_per_sk_block + 1;

            dp_num_blocks      = dp_tiles;
            dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
        }

        n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
        // using multiple blocks for parallel reduction
        reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;

        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
@@ -1157,13 +1084,14 @@ struct BlockToCTileMap_GemmStreamK
        }
#if 0
        printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
               "sk_num_blocks:%d, "
               "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               // get_grid_dims(num_cu, occupancy).x,
               get_grid_dims().x,
               num_tiles,
               dp_tiles,
@@ -1171,7 +1099,7 @@ struct BlockToCTileMap_GemmStreamK
               sk_num_blocks,
               sk_total_iters,
               dp_start_block_idx,
               dp_num_blocks,
               k_iters_per_tile.get(),
               k_iters_per_big_block,
@@ -1195,7 +1123,8 @@ struct BlockToCTileMap_GemmStreamK
        return k_iters_per_tile.div(sk_total_iters);
    }

    // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
    __host__ __device__ constexpr dim3 get_grid_dims() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
@@ -1203,6 +1132,16 @@ struct BlockToCTileMap_GemmStreamK
        }
        else
            return dim3(reduction_start_block_idx, 1, 1);
        // return dim3(num_cu * occupancy, 1, 1); // HS
    }

    __host__ __device__ uint32_t total_blocks_allocated() const
    {
        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
        {
            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx + get_sk_tiles());
        }
        else
            return __builtin_amdgcn_readfirstlane(reduction_start_block_idx);
    }

    __device__ uint32_t get_block_idx() const
......
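
To make the 2-tile split above concrete, the following stand-alone sketch redoes the constructor's arithmetic for the README case (M=3840, N=4096, K=4096, sk_blocks=312). The tile shape (128x128x32), CU count (104) and occupancy (3) are assumptions chosen to match that run, not values taken from the library:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    const uint32_t M = 3840, N = 4096, K = 4096;
    const uint32_t MPerBlock = 128, NPerBlock = 128, KPerBlock = 32; // assumed tile shape
    const uint32_t num_cu = 104, occupancy = 3;                      // assumed: one MI250 GCD
    const uint32_t sk_blocks = 312;                                  // arg10 from the README run

    const auto ceil_div = [](uint32_t a, uint32_t b) { return (a + b - 1) / b; };

    const uint32_t num_tiles        = ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock); // 30*32 = 960
    const uint32_t k_iters_per_tile = ceil_div(K, KPerBlock);                          // 128

    const uint32_t grid_size  = occupancy * num_cu; // 312
    const bool     big_enough = num_tiles > grid_size;

    // max of 2 sk tiles per block: one full "dispatch" plus the remainder tiles
    const uint32_t sk_tiles = big_enough ? grid_size + num_tiles % grid_size : num_tiles; // 336
    const uint32_t dp_tiles = big_enough ? num_tiles - sk_tiles : 0;                      // 624

    const uint32_t sk_total_iters = k_iters_per_tile * sk_tiles; // 43008

    // split sk blocks into "little" (m iters) and "big" (m+1 iters): b = sk_total_iters - m * sk_blocks
    const uint32_t k_iters_per_sk_block  = sk_total_iters / sk_blocks;                        // 137
    const uint32_t sk_num_big_blocks     = sk_total_iters - k_iters_per_sk_block * sk_blocks; // 264
    const uint32_t k_iters_per_big_block = k_iters_per_sk_block + 1;                          // 138

    const uint32_t dp_start_block_idx        = ceil_div(sk_blocks, num_cu) * num_cu; // 312
    const uint32_t reduction_start_block_idx = dp_start_block_idx + dp_tiles;        // 936

    printf("tiles:%u sk_tiles:%u dp_tiles:%u big_sk_blocks:%u iters(big/little):%u/%u "
           "dp_start:%u reduction_start:%u\n",
           num_tiles, sk_tiles, dp_tiles, sk_num_big_blocks,
           k_iters_per_big_block, k_iters_per_sk_block,
           dp_start_block_idx, reduction_start_block_idx);
    return 0;
}
```

Passing sk_blocks = 0 (or omitting arg10, as the README notes) takes the first branch instead, which launches one data-parallel block per output tile and no stream-k blocks.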
@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t num_cu, occupancy; // stream-k arguments
        Block2CTileMap block_mapping;

        Argument(const FloatAB* p_a_grid_,
@@ -156,8 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 uint32_t num_cu_,
                 uint32_t occupancy_,
                 uint32_t num_sk_blocks_)
            : p_a_grid(p_a_grid_),
              p_b_grid(p_b_grid_),
@@ -168,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
              num_cu(num_cu_),
              occupancy(occupancy_),
              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
        {
        }
@@ -452,16 +455,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                               Block2CTileMap block_mapping,
                               void* __restrict__ p_shared_block)
    {
        uint32_t m = M;
        uint32_t n = N;
        uint32_t k = K;

        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
        uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock;

        uint32_t stride_a = StrideA;
        uint32_t stride_b = StrideB;
        uint32_t stride_c = StrideC;

        const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
        const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
        const auto c_grid_desc_m_n     = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
@@ -520,623 +523,640 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        // gridwise GEMM pipeline
        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

        uint32_t block_idx = block_mapping.get_block_idx();

        uint32_t* p_semaphore =
            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));

        // offset for last acc buffer of this block
        uint32_t block_acc_offset =
            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
            NPerBlock;

        uint32_t iter_start, iter_end;
        bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
        uint32_t total_iter_length;

#pragma unroll
        // stream-k: loop over new work for all the persistent blocks.
        for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
        {
            is_sk_block = block_idx < block_mapping.sk_num_blocks;
            is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                          block_idx < block_mapping.reduction_start_block_idx;
            is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
                               block_idx < block_mapping.dp_start_block_idx;
            if(is_padding_block)
            {
                continue;
            }

            block_mapping.get_block_itr(block_idx, iter_start, iter_end);
            total_iter_length = iter_end - iter_start;

            if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
            {
                is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
                if(is_reduction_block)
                {
                    // descriptors
                    constexpr auto cluster_length_reduce = GetClusterLengthReduction();
                    constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
                    const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
                        make_multi_index(get_thread_local_1d_id()));
                    const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
                    const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];

                    constexpr auto MReduceIters = math::integer_divide_ceil(
                        Number<MPerBlock>{}, cluster_length_reduce.At(I0));
                    constexpr auto NReduceIters = math::integer_divide_ceil(
                        Number<NPerBlock>{},
                        cluster_length_reduce.At(I1) *
                            Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});

                    constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
                        make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
                    constexpr auto acc_thread_buf_store_desc =
                        make_naive_tensor_descriptor_packed(make_tuple(
                            I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));

                    constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();

                    constexpr auto partial_acc_load_step_n = make_multi_index(
                        0,
                        cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
                    constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
                        0,
                        -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                            CBlockTransferScalarPerVector_NWaveNPerXDL);
                    constexpr auto partial_acc_load_step_m =
                        make_multi_index(cluster_length_reduce.At(I0), 0);

                    constexpr auto partial_acc_store_step_n = make_multi_index(
                        0,
                        0,
                        0,
                        cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
                    constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
                        0,
                        0,
                        0,
                        -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                            CBlockTransferScalarPerVector_NWaveNPerXDL);
                    constexpr auto partial_acc_store_step_m =
                        make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);

                    StaticBuffer<AddressSpaceEnum::Vgpr,
                                 FloatAcc,
                                 CBlockTransferScalarPerVector_NWaveNPerXDL,
                                 true>
                        parcial_acc_buf;
                    StaticBuffer<AddressSpaceEnum::Vgpr,
                                 FloatAcc,
                                 CBlockTransferScalarPerVector_NWaveNPerXDL,
                                 true>
                        acc_buf;

                    // start to compute
                    auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
                    auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);

                    workgroup_barrier wg_barrier(p_semaphore);

                    uint32_t tile_acc_offset_start =
                        block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
                    uint32_t tile_acc_offset_end =
                        block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);

                    auto acc_load = ThreadwiseTensorSliceTransfer_v2<
                        FloatAcc,                                                // SrcData,
                        FloatAcc,                                                // DstData,
                        decltype(c_partial_acc_block_m_n),                       // SrcDesc,
                        decltype(acc_thread_buf_load_desc),                      // DstDesc,
                        Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
                        Sequence<0, 1>,                                          // DimAccessOrder,
                        1,                                                       // SrcVectorDim,
                        CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
                        1,                                          // SrcScalarStrideInVector,
                        false                                       // SrcResetCoordinateAfterRun,
                        >{c_partial_acc_block_m_n,
                          make_multi_index(thread_m_cluster_id,
                                           thread_n_cluster_id *
                                               CBlockTransferScalarPerVector_NWaveNPerXDL)};

                    auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
                        FloatAcc,                                                // SrcData,
                        FloatC,                                                  // DstData,
                        decltype(acc_thread_buf_store_desc),                     // SrcDesc,
                        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
                        CElementwiseOperation, // ElementwiseOperation,
                        Sequence<1,
                                 1,
                                 1,
                                 CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
                        Sequence<0, 1, 2, 3>,                                 // DimAccessOrder,
                        3,                                                    // DstVectorDim,
                        CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
                        InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
                        1,                              // DstScalarStrideInVector,
                        false                           // DstResetCoordinateAfterRun,
                        >{c_grid_desc_mblock_mperblock_nblock_nperblock,
                          make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                                           thread_m_cluster_id,
                                           __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                           thread_n_cluster_id *
                                               CBlockTransferScalarPerVector_NWaveNPerXDL),
                          CElementwiseOperation{}};

                    // block synchronization
                    wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
#if 0
                    if(threadIdx.x == 0) {
                        printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(block_idx),
                            reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                            __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                            __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
                    }
#endif
                    using Accumulation = ck::detail::
                        AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;

                    for(int i_m = 0; i_m < MReduceIters; i_m++)
                    {
                        static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
                            acc_buf.Clear();
                            for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                            {
                                auto c_partial_acc_buf =
                                    make_dynamic_buffer<AddressSpaceEnum::Global,
                                                        AmdBufferCoherenceEnum::GLC>(
                                        reinterpret_cast<FloatAcc*>(p_workspace) +
                                            i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                                        c_partial_acc_block_m_n.GetElementSpaceSize());

                                acc_load.Run(c_partial_acc_block_m_n,
                                             c_partial_acc_buf,
                                             acc_thread_buf_load_desc,
                                             make_tuple(I0, I0),
                                             parcial_acc_buf);

                                static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
                                    [&](auto i_vec) {
                                        constexpr auto offset =
                                            acc_thread_buf_load_desc.CalculateOffset(
                                                make_tuple(0, i_vec));
                                        Accumulation::Calculate(acc_buf(Number<offset>{}),
                                                                parcial_acc_buf[Number<offset>{}]);
                                    });
                            }

                            if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
                               NPerBlock)
                            {
                                acc_store.Run(acc_thread_buf_store_desc,
                                              make_tuple(I0, I0, I0, I0),
                                              acc_buf,
                                              c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              c_grid_buf);
                            }

                            if constexpr(NReduceIters != 1)
                            {
                                if constexpr(i_n_reduce != (NReduceIters - 1))
                                {
                                    acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                                partial_acc_load_step_n);
                                    acc_store.MoveDstSliceWindow(
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        partial_acc_store_step_n);
                                }
                                else
                                {
                                    acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                                partial_acc_load_step_n_reverse);
                                    acc_store.MoveDstSliceWindow(
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        partial_acc_store_step_n_reverse);
                                }
                            }
                        });
                        {
                            acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                        partial_acc_load_step_m);
                            acc_store.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                partial_acc_store_step_m);
                        }
                    }
                    return;
                }
            }
            while(true)
            {
                uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
                    block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
                uint32_t tile_idx, iter_offset;
                block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
                iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
                auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

                const index_t m_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);

                const index_t n_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);

                const index_t k0_block_data_idx_on_grid =
                    __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);

                // A matrix blockwise copy
                auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                    ThisThreadBlock,
                    AElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
                    Sequence<K0PerBlock, MPerBlock, K1>,
                    ABlockTransferThreadClusterLengths_K0_M_K1,
                    ABlockTransferThreadClusterArrangeOrder,
                    FloatAB,
                    FloatAB,
                    decltype(a_k0_m_k1_grid_desc),
                    decltype(a_block_desc_k0_m_k1),
                    ABlockTransferSrcAccessOrder,
                    Sequence<1, 0, 2>,
                    ABlockTransferSrcVectorDim,
                    2,
                    ABlockTransferSrcScalarPerVector,
                    ABlockTransferDstScalarPerVector_K1,
                    1,
                    1,
                    AThreadTransferSrcResetCoordinateAfterRun,
                    true>(a_k0_m_k1_grid_desc,
                          make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
                          a_element_op,
                          a_block_desc_k0_m_k1,
                          make_multi_index(0, 0, 0),
                          ck::tensor_operation::element_wise::PassThrough{});

                // B matrix blockwise copy
                auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
                    ThisThreadBlock,
                    BElementwiseOperation,
                    ck::tensor_operation::element_wise::PassThrough,
                    InMemoryDataOperationEnum::Set,
                    Sequence<K0PerBlock, NPerBlock, K1>,
                    BBlockTransferThreadClusterLengths_K0_N_K1,
                    BBlockTransferThreadClusterArrangeOrder,
                    FloatAB,
                    FloatAB,
                    decltype(b_k0_n_k1_grid_desc),
                    decltype(b_block_desc_k0_n_k1),
                    BBlockTransferSrcAccessOrder,
                    Sequence<1, 0, 2>,
                    BBlockTransferSrcVectorDim,
                    2,
                    BBlockTransferSrcScalarPerVector,
                    BBlockTransferDstScalarPerVector_K1,
                    1,
                    1,
                    BThreadTransferSrcResetCoordinateAfterRun,
                    true>(b_k0_n_k1_grid_desc,
                          make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
                          b_element_op,
                          b_block_desc_k0_n_k1,
                          make_multi_index(0, 0, 0),
                          ck::tensor_operation::element_wise::PassThrough{});

                const index_t num_k_block_main_loop = current_iter_length;

                gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                           a_block_desc_k0_m_k1,
                                           a_blockwise_copy,
                                           a_grid_buf,
                                           a_block_buf,
                                           a_block_slice_copy_step,
                                           b_k0_n_k1_grid_desc,
                                           b_block_desc_k0_n_k1,
                                           b_blockwise_copy,
                                           b_grid_buf,
                                           b_block_buf,
                                           b_block_slice_copy_step,
                                           blockwise_gemm,
                                           c_thread_buf,
                                           num_k_block_main_loop);

                // output: register to global memory
                {
                    constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
                    constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

                    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
                        blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
                        blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

                    constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
                    constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
                    constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
                    constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
                    constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
                    constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
                    constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
                    constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);

                    constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
                        GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();

                    constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
                        GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();

                    auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        reinterpret_cast<FloatCShuffle*>(p_shared_block),
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());

                    auto c_partial_acc_buf =
                        make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                            reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                                .GetElementSpaceSize());

                    constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                        transform_tensor_descriptor(
                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                            make_tuple(make_freeze_transform(I0), // freeze mblock
                                       make_unmerge_transform(
                                           make_tuple(CShuffleMRepeatPerShuffle,
                                                      M1,
                                                      M2,
                                                      M3,
                                                      M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL
                                       make_freeze_transform(I0), // freeze nblock
                                       make_unmerge_transform(
                                           make_tuple(CShuffleNRepeatPerShuffle,
                                                      N1,
                                                      N2))), // N1 = NWave, N2 = NPerXDL
                            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                            make_tuple(Sequence<>{},
                                       Sequence<0, 2, 4, 5, 6>{},
                                       Sequence<>{},
                                       Sequence<1, 3, 7>{}));

                    // calculate origin of thread output tensor on global memory
                    // blockwise GEMM c matrix starting index
                    const auto c_thread_mtx_on_block =
                        blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

                    const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
                    const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

                    const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                        make_single_stage_tensor_adaptor(
                            make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                            make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                            make_tuple(Sequence<0>{}));

                    const auto m_thread_data_on_block_idx =
                        m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                            make_multi_index(m_thread_data_on_block));

                    const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                        make_single_stage_tensor_adaptor(
                            make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                            make_tuple(Sequence<0, 1, 2>{}),
                            make_tuple(Sequence<0>{}));

                    const auto n_thread_data_on_block_idx =
                        n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                            make_multi_index(n_thread_data_on_block));

                    // VGPR to LDS
                    auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                        FloatAcc,
                        FloatCShuffle,
                        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
                        decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                        ck::tensor_operation::element_wise::PassThrough,
                        Sequence<CShuffleMRepeatPerShuffle,
                                 CShuffleNRepeatPerShuffle,
                                 I1,
                                 I1,
                                 M2,
                                 I1,
                                 M4,
                                 I1>,
                        Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                        7,
                        1,
                        InMemoryDataOperationEnum::Set,
                        1,
                        true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_multi_index(0,
                                               0,
                                               m_thread_data_on_block_idx[I1],
                                               n_thread_data_on_block_idx[I1],
                                               m_thread_data_on_block_idx[I2],
                                               m_thread_data_on_block_idx[I3],
                                               m_thread_data_on_block_idx[I4],
                                               n_thread_data_on_block_idx[I2]),
                              ck::tensor_operation::element_wise::PassThrough{}};

                    // LDS to global
                    auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
                        ThisThreadBlock,       // index_t BlockSize,
                        CElementwiseOperation, // ElementwiseOperation,
                        // InMemoryDataOperationEnum::Set, // DstInMemOp,
                        Sequence<1,
                                 CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                                 1,
                                 CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
                        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                        FloatCShuffle,        // typename SrcData,
                        FloatC,               // typename DstData,
                        decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
                        decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                        Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                        3,                                          // index_t VectorDim,
                        CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                        false, // bool ThreadTransferSrcResetCoordinateAfterRun,
                        false> // bool ThreadTransferDstResetCoordinateAfterRun
                        {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                         make_multi_index(0, 0, 0, 0),
                         c_grid_desc_mblock_mperblock_nblock_nperblock,
                         make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                                          0,
                                          __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                          0),
                         c_element_op};

                    // LDS to global partial acc
                    auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
                        ThisThreadBlock,       // index_t BlockSize,
                        CElementwiseOperation, // ElementwiseOperation,
                        // InMemoryDataOperationEnum::Set, // DstInMemOp,
                        Sequence<1,
                                 CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                                 1,
                                 CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
                        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                        Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                        FloatCShuffle,        // typename SrcData,
                        FloatCShuffle,        // typename DstData,
                        decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
                        decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
                        Sequence<0, 1, 2, 3>,                       // typename DimAccessOrder,
                        3,                                          // index_t VectorDim,
                        CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
                        false, // bool ThreadTransferSrcResetCoordinateAfterRun, => needs to be
                               // false, otherwise has scratch
                        false> // bool ThreadTransferDstResetCoordinateAfterRun, => needs to be
                               // false, otherwise has scratch
                        {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                         make_multi_index(0, 0, 0, 0),
                         c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                         make_multi_index(0, 0, 0, 0),
                         c_element_op};

                    constexpr auto mxdlperwave_forward_step =
                        make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
                    constexpr auto nxdlperwave_forward_step =
                        make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL);
                    constexpr auto nxdlperwave_backward_step =
                        make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL);

                    static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
                        constexpr auto mxdlperwave = mxdlperwave_iter;

                        static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}(
                            [&](auto nxdlperwave_iter) {
                                constexpr bool nxdlperwave_forward_sweep =
                                    (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

                                constexpr index_t nxdlperwave_value =
                                    nxdlperwave_forward_sweep
                                        ? nxdlperwave_iter
                                        : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);

                                constexpr auto nxdlperwave = Number<nxdlperwave_value>{};

                                // make sure it's safe to do ds_write
                                block_sync_lds();

                                // VGPR to LDS
                                c_thread_copy_vgpr_to_lds.Run(
                                    c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                    make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
                                    c_thread_buf,
                                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                    c_block_buf);

                                // make sure it's safe to do ds_read
                                block_sync_lds();

                                c_block_copy_lds_to_global.SetSrcSliceOrigin(
                                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                    make_tuple(0, 0, 0, 0));

                                // LDS to global
                                if(is_dp_block)
                                    c_block_copy_lds_to_global
                                        .template Run<decltype(c_block_buf),
                                                      decltype(c_grid_buf),
                                                      InMemoryDataOperationEnum::Set>(
                                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                            c_block_buf,
                                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                                            c_grid_buf);
                                else if(is_sk_block)
                                {
                                    if constexpr(Block2CTileMap::ReductionStrategy ==
                                                 StreamKReductionStrategy::Reduction)
                                    {
                                        // constexpr offset
                                        c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                            make_tuple(0, 0, 0, 0));

                                        c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                            make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

                                        c_block_copy_lds_to_partial_acc.template Run<
                                            decltype(c_block_buf),
                                            decltype(c_partial_acc_buf),
                                            InMemoryDataOperationEnum::Set>(
                                            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                            c_block_buf,
                                            c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                                            c_partial_acc_buf);
                                    }
                                    else if constexpr(Block2CTileMap::ReductionStrategy ==
                                                      StreamKReductionStrategy::Atomic)
                                    {
                                        c_block_copy_lds_to_global
                                            .template Run<decltype(c_block_buf),
                                                          decltype(c_grid_buf),
                                                          InMemoryDataOperationEnum::AtomicAdd>(
                                                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                                                c_block_buf,
                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                c_grid_buf);
                                    }
                                }

                                // move on nxdlperwave dimension
                                if constexpr(nxdlperwave_forward_sweep &&
                                             (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
                                {
                                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        nxdlperwave_forward_step);
                                }
                                else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
                                {
                                    c_block_copy_lds_to_global.MoveDstSliceWindow(
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        nxdlperwave_backward_step);
                                }
                            });

                        // move on mxdlperwave dimension
                        if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
                        {
                            c_block_copy_lds_to_global.MoveDstSliceWindow(
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                mxdlperwave_forward_step);
                        }
                    });

                    if constexpr(Block2CTileMap::ReductionStrategy ==
                                 StreamKReductionStrategy::Reduction)
                    {
                        if(is_sk_block)
                        {
                            // increase the counter for this tile
                            workgroup_barrier wg_barrier(p_semaphore);
                            wg_barrier.inc(tile_idx);
                        }
                    }
                }

                // exit condition
                iter_end -= current_iter_length;
                if(iter_end <= iter_start)
                    break;

                if constexpr(Block2CTileMap::ReductionStrategy ==
                             StreamKReductionStrategy::Reduction)
                {
                    block_acc_offset -= MPerBlock * NPerBlock;
                }

                // make sure next loop LDS is ready for use
                block_sync_lds();
            }
        }
    }
......
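
Under the Reduction strategy, the kernel above coordinates stream-k producers and the reduction blocks through a per-tile counting semaphore: each SK block calls wg_barrier.inc(tile_idx) once its partial accumulation for that tile is in the workspace, and the reduction block calls wg_barrier.wait_eq(...) for the expected number of partials before summing them. A simplified sketch of that handshake, using plain HIP atomics in place of CK's workgroup_barrier (the names below are illustrative, not the CK API):

```cpp
#include <hip/hip_runtime.h>

// Producer side: an SK block announces that one more partial accumulation for
// tile_idx is visible in the global workspace.
__device__ void signal_partial_ready(uint32_t* p_semaphore, uint32_t tile_idx)
{
    __syncthreads();  // all threads of the workgroup finished writing the partial tile
    __threadfence();  // make those global writes visible to other workgroups
    if(threadIdx.x == 0)
        atomicAdd(&p_semaphore[tile_idx], 1u); // plays the role of wg_barrier.inc(tile_idx)
}

// Consumer side: the reduction block spins until `expected` partials
// (tile_acc_offset_end - tile_acc_offset_start in the kernel above) have arrived.
__device__ void wait_for_partials(uint32_t* p_semaphore, uint32_t tile_idx, uint32_t expected)
{
    if(threadIdx.x == 0)
    {
        // plays the role of wg_barrier.wait_eq(tile_idx, expected);
        // atomicAdd of 0 serves as an atomic read of the counter
        while(atomicAdd(&p_semaphore[tile_idx], 0u) != expected) {}
    }
    __syncthreads(); // release the rest of the workgroup; the kernel then reads the
                     // partials back through GLC-coherent buffer loads
}
```

The Atomic strategy named in the commit title skips this round trip entirely: SK blocks write their partial tiles straight to C with InMemoryDataOperationEnum::AtomicAdd, as in the `else if constexpr(... Atomic)` branch of the kernel above.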