Commit 5b1e2442 authored by Harisankar Sadasivan

2-tile sk + DP with atomics for FP16

parent 2ae16e90
......@@ -21,3 +21,25 @@ Warm up
Start running 5 times...
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
```
# Instructions for ```example_gemm_xdl_streamk```
## Run ```example_gemm_xdl_streamk```
```bash
# arg1: verification (0=no, 1=yes)
# arg2: initialization (0=no init, 1=integer value, 2=decimal value)
# arg3: time kernel (0=no, 1=yes)
# arg4 to arg9: M (256x), N (128x), K (32x), StrideA, StrideB, StrideC
# arg10: NumSKBlocks (optional; if omitted or 0, falls back to data-parallel GEMM)
bin/example_gemm_xdl_streamk 1 2 1 3840 4096 4096 4096 4096 4096 312
```
Result (MI250 @ 1700 MHz, 181 TFlops peak FP16 on 1 die)
```
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
Recommended grid size :312
Perf: 1.21689 ms, 105.884 TFlops, 79.2748 GB/s, GemmXdlStreamK_RRR_B256_Vec8x2x8_128x128x4x8
```
......@@ -137,6 +137,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
"setting");
}
// stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
// dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
dim3 grid_dims = karg.block_mapping.get_grid_dims();
float ave_time = 0;
......@@ -268,22 +270,19 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
uint32_t NumSKBlocks = 0xffffffff)
uint32_t NumSKBlocks = 0)
{
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
hip_check_error(rtn);
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
num_cu * occupancy);
return Argument{p_a,
p_b,
......@@ -318,17 +317,12 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
{
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte());
hip_check_error(rtn);
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
......
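For context, the grid size recommended above is simply CUs × occupancy for the chosen kernel. Below is a minimal, self-contained sketch of that query, assuming only the standard HIP runtime API; the helper name is illustrative and not part of the library:

```cpp
#include <hip/hip_runtime.h>

// Illustrative helper (not the library API): recommended stream-k grid size
// = number of CUs x max resident workgroups per CU for the given kernel.
template <typename Kernel>
int recommended_streamk_grid_size(Kernel kernel, int block_size, size_t dyn_lds_bytes)
{
    int occupancy = 0;
    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
           &occupancy, kernel, block_size, dyn_lds_bytes) != hipSuccess)
        return -1;

    int dev = 0;
    hipDeviceProp_t prop;
    if(hipGetDevice(&dev) != hipSuccess || hipGetDeviceProperties(&prop, dev) != hipSuccess)
        return -1;

    // Assuming full GPU availability: one persistent workgroup per occupancy slot per CU.
    return prop.multiProcessorCount * occupancy;
}
```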
......@@ -1010,113 +1010,27 @@ struct BlockToCTileMap_GemmStreamK
MDiv eqav_tiles_big; // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
uint32_t k,
uint32_t num_cu,
uint32_t occupancy,
uint32_t sk_blocks = 0xffffffff)
uint32_t sk_blocks = 0)
{
// total output tiles
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
// one cu can hold one wg at one time, from the whole chip's point of view
// if number of wg is same as num_cu, we call it 1 dispatch
// if number of wg is 2x num_cu, we call it 2 dispatches.
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
// dispatch)
//
uint32_t full_dispatches = num_tiles / num_cu;
uint32_t full_dispatch_tiles = full_dispatches * num_cu;
uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles;
uint32_t sk_occupancy = occupancy;
uint32_t dp_tiles = full_dispatch_tiles;
uint32_t sk_tiles = partial_dispatche_tiles;
if(full_dispatches < occupancy)
{
// in this case, we allocate all blocks as sk blocks
// sk_occupancy = occupancy - full_dispatches;
sk_occupancy = 1; // TODO: single occ seems better
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
{
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
// occupancy = 3, full_dispatches = 5, 8, 11 ...
// occupancy = 4, full_dispatches = 7, 11 ...
sk_occupancy = 1; // left 1 slot for sk occupancy
dp_tiles = full_dispatch_tiles;
sk_tiles = partial_dispatche_tiles;
}
else
{
// others, we reduce 1 dispatch from dp, together with partial dispatch,
// to construct sk dispatch
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
dp_tiles = full_dispatch_tiles - num_cu;
sk_tiles = partial_dispatche_tiles + num_cu;
}
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
uint32_t max_sk_tiles =
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
: math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
// if use dp for sk-block, how many iters do we need
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{
uint32_t tentative_sk_iters_per_block =
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
// TODO: carefully adjust this parameter
// the more sk_blocks_per_tile, the worse the overhead
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
if(tentative_sk_blocks % sk_tiles != 0)
{
// penalty for uneven divide
cross_sk_blocks_overhead +=
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
}
uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
if(tentative_sk_score < best_sk_score)
{
best_sk_score = tentative_sk_score;
sk_num_blocks = tentative_sk_blocks;
}
}
if(best_sk_score >= dp_for_sk_iters)
// default to regular DP GEMM if sk blocks == 0
sk_num_blocks = sk_blocks;
if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
{
sk_num_blocks = 0;
}
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
if(sk_num_blocks == 0)
{
dp_tiles = num_tiles;
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
......@@ -1124,8 +1038,20 @@ struct BlockToCTileMap_GemmStreamK
dp_start_block_idx = 0;
sk_total_iters = 0; // clear these tiles
}
// 2-tile sk + DP GEMM
else
{
// grid size
uint32_t grid_size = occupancy * num_cu;
// check if there's enough work for DP + stream-k
bool bigEnough = num_tiles > grid_size;
// max of 2 sk tiles per block
uint32_t sk_tiles = bigEnough ? grid_size + num_tiles % grid_size : num_tiles;
// remaining tiles are DP tiles
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
// k_iters_per_sk_block is the floor of the average number of K iterations each sk block loops over.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
......@@ -1144,8 +1070,9 @@ struct BlockToCTileMap_GemmStreamK
dp_num_blocks = dp_tiles;
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
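The else branch above is the new 2-tile stream-k + DP partition. A hedged, host-side sketch of that arithmetic (hypothetical names, not the library API):

```cpp
#include <cstdint>

struct TileSplit
{
    uint32_t grid_size;      // persistent workgroups launched (occupancy * num_cu)
    uint32_t sk_tiles;       // output tiles handled by stream-k blocks (at most 2 per block)
    uint32_t dp_tiles;       // remaining tiles handled by data-parallel blocks
    uint32_t sk_total_iters; // K iterations shared among the stream-k blocks
};

inline TileSplit split_tiles_2tile_sk(uint32_t num_tiles,
                                      uint32_t k_iters_per_tile,
                                      uint32_t num_cu,
                                      uint32_t occupancy)
{
    TileSplit s{};
    s.grid_size = occupancy * num_cu;
    // DP + stream-k only makes sense when there are more tiles than workgroups.
    const bool big_enough = num_tiles > s.grid_size;
    // Stream-k covers one full wave of tiles plus the uneven remainder,
    // so each stream-k block sees at most 2 tiles.
    s.sk_tiles       = big_enough ? s.grid_size + num_tiles % s.grid_size : num_tiles;
    s.dp_tiles       = big_enough ? num_tiles - s.sk_tiles : 0;
    s.sk_total_iters = k_iters_per_tile * s.sk_tiles;
    return s;
}
```

For example, if num_cu = 104 and occupancy = 3 (grid_size = 312, matching the recommended grid size in the README output above), a 3840x4096 output with 128x128 tiles gives num_tiles = 960, so sk_tiles = 312 + 24 = 336 and dp_tiles = 624.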
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
// using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
......@@ -1157,13 +1084,14 @@ struct BlockToCTileMap_GemmStreamK
}
#if 0
printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
"sk_tiles:%u, workspace(acc float):%u\n",
num_cu,
occupancy,
// get_grid_dims(num_cu, occupancy).x,
get_grid_dims().x,
num_tiles,
dp_tiles,
......@@ -1171,7 +1099,7 @@ struct BlockToCTileMap_GemmStreamK
sk_num_blocks,
sk_total_iters,
dp_start_block_idx,
dp_iters_per_block,
dp_num_blocks,
k_iters_per_tile.get(),
k_iters_per_big_block,
......@@ -1195,7 +1123,8 @@ struct BlockToCTileMap_GemmStreamK
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ dim3 get_grid_dims() const
// __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
__host__ __device__ constexpr dim3 get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
......@@ -1203,6 +1132,16 @@ struct BlockToCTileMap_GemmStreamK
}
else
return dim3(reduction_start_block_idx, 1, 1);
// return dim3(num_cu * occupancy, 1, 1); // HS
}
__host__ __device__ uint32_t total_blocks_allocated() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return __builtin_amdgcn_readfirstlane(reduction_start_block_idx + get_sk_tiles());
}
else
return __builtin_amdgcn_readfirstlane(reduction_start_block_idx);
}
__device__ uint32_t get_block_idx() const
......
......@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t num_cu, occupancy; // stream-k arguments
Block2CTileMap block_mapping;
Argument(const FloatAB* p_a_grid_,
......@@ -156,8 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
uint32_t num_cu,
uint32_t occupancy,
uint32_t num_cu_,
uint32_t occupancy_,
uint32_t num_sk_blocks_)
: p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_),
......@@ -168,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
StrideA(StrideA_),
StrideB(StrideB_),
StrideC(StrideC_),
block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_)
num_cu(num_cu_),
occupancy(occupancy_),
block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
{
}
......@@ -461,7 +464,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
uint32_t stride_a = StrideA;
uint32_t stride_b = StrideB;
uint32_t stride_c = StrideC;
uint32_t block_idx = block_mapping.get_block_idx();
const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a);
const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c);
......@@ -520,39 +523,53 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
uint32_t block_idx = block_mapping.get_block_idx();
bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx;
bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
block_idx < block_mapping.dp_start_block_idx;
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
uint32_t iter_start, iter_end;
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
uint32_t total_iter_length;
#pragma unroll
// stream-k: persistent blocks loop over logical block ids to pick up new work
// (a standalone sketch of this pattern follows the end of Run below).
for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
{
is_sk_block = block_idx < block_mapping.sk_num_blocks;
is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx;
is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
block_idx < block_mapping.dp_start_block_idx;
if(is_padding_block)
return;
{
continue;
}
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
total_iter_length = iter_end - iter_start;
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx =
reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters =
math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
......@@ -560,15 +577,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index(
0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse =
make_multi_index(0,
0,
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_m =
......@@ -579,8 +598,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
0,
0,
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_n_reverse =
make_multi_index(0,
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
......@@ -600,7 +619,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
acc_buf;
// start to compute
auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
workgroup_barrier wg_barrier(p_semaphore);
......@@ -632,7 +651,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<1,
1,
1,
CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
......@@ -652,7 +674,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(block_idx),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
......@@ -723,19 +745,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
return;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -755,8 +772,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
__builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -776,8 +793,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
true>(a_k0_m_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0_m_k1,
......@@ -785,8 +801,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -806,8 +822,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
true>(b_k0_n_k1_grid_desc,
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0_n_k1,
......@@ -868,7 +883,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(
......@@ -994,10 +1010,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false,
// otherwise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false,
// otherwise has scratch
false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, otherwise has scratch
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
......@@ -1014,7 +1030,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
constexpr auto mxdlperwave = mxdlperwave_iter;
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}(
[&](auto nxdlperwave_iter) {
constexpr bool nxdlperwave_forward_sweep =
(mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
......@@ -1045,7 +1062,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// LDS to global
if(is_dp_block)
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
c_block_copy_lds_to_global
.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
......@@ -1066,8 +1084,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf),
c_block_copy_lds_to_partial_acc.template Run<
decltype(c_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
......@@ -1131,7 +1149,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
......@@ -1139,6 +1158,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
block_sync_lds();
}
}
}
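The additional closing brace in this hunk ends the persistent-block for loop opened near the top of Run. As a standalone, simplified illustration of that pattern (a hypothetical device helper, not the actual kernel):

```cpp
#include <hip/hip_runtime.h>

// Hypothetical device helper, simplified from Run(): every launched workgroup
// strides through logical block ids until all allocated blocks are consumed,
// classifying each id as stream-k, padding, data-parallel, or reduction work.
__device__ void persistent_block_loop_sketch(uint32_t first_block_idx,
                                             uint32_t total_blocks_allocated,
                                             uint32_t sk_num_blocks,
                                             uint32_t dp_start_block_idx,
                                             uint32_t reduction_start_block_idx)
{
    for(uint32_t block_idx = first_block_idx; block_idx < total_blocks_allocated;
        block_idx += gridDim.x)
    {
        const bool is_sk_block = block_idx < sk_num_blocks;
        const bool is_padding_block =
            block_idx >= sk_num_blocks && block_idx < dp_start_block_idx;
        const bool is_dp_block = block_idx >= dp_start_block_idx &&
                                 block_idx < reduction_start_block_idx;

        if(is_padding_block)
            continue; // ids between the sk and dp ranges carry no work

        // ... stream-k partial tiles, data-parallel tiles, or the final reduction ...
        (void)is_sk_block;
        (void)is_dp_block;
    }
}
```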
template <typename Layout>
struct LStr
......