Commit 5490b99c authored by Harisankar Sadasivan

2-tile stream-k with reduction

parent 5b1e2442
@@ -78,7 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         BlockToCTileMap_GemmStreamK<MPerBlock,
                                     NPerBlock,
                                     K0PerBlock * K1,
-                                    StreamKReductionStrategy::Atomic>,
+                                    // StreamKReductionStrategy::Atomic>,
+                                    StreamKReductionStrategy::Reduction>,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
@@ -139,11 +140,20 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             // stream-k: calculate the number of blocks to be launched based on #CUs and #occupancy
             // dim3 grid_dims = karg.block_mapping.get_grid_dims(karg.num_cu, karg.occupancy);
-            dim3 grid_dims = karg.block_mapping.get_grid_dims();
-            float ave_time = 0;
+            int occupancy, num_cu;
             const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+            hip_check_error(
+                hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+            hipDeviceProp_t dev_prop;
+            hipDevice_t dev;
+            hip_check_error(hipGetDevice(&dev));
+            hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+            num_cu = dev_prop.multiProcessorCount;
+            dim3 grid_dims = karg.block_mapping.sk_num_blocks;
+            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
+                   num_cu * occupancy);
+            float ave_time = 0;
             // TODO: remove clear buffer for streamk kernels
             if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
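Note for readers following the invoker change above: the occupancy query that moved into Run() is plain HIP runtime API. A self-contained sketch of the same calls follows; the kernel and BlockSize here are placeholders, not the ones from this commit, and error checks are omitted:

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {}

int main()
{
    constexpr int BlockSize = 256; // placeholder; the device op uses its template BlockSize
    int occupancy = 0, num_cu = 0;

    // max resident blocks per CU for this kernel at BlockSize threads, 0 bytes of dynamic LDS
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, BlockSize, 0);

    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    hipGetDevice(&dev);
    hipGetDeviceProperties(&dev_prop, dev);
    num_cu = dev_prop.multiProcessorCount;

    // num_cu * occupancy persistent blocks saturate the GPU; this is the figure
    // the new printf reports as the recommended stream-k block count for tuning
    printf("recommended #stream-k blocks: %d\n", num_cu * occupancy);
    return 0;
}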
@@ -272,30 +282,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                             CElementwiseOperation,
                             uint32_t NumSKBlocks = 0)
     {
-        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
-        printf("Assuming full GPU availability, recommended stream-k grid size for tuning :%0d\n",
-               num_cu * occupancy);
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        M,
-                        N,
-                        K,
-                        StrideA,
-                        StrideB,
-                        StrideC,
-                        static_cast<uint32_t>(num_cu),
-                        static_cast<uint32_t>(occupancy),
-                        NumSKBlocks};
+        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, NumSKBlocks};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -315,15 +303,6 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                             CElementwiseOperation,
                             index_t NumSKBlocks = 0) override
     {
-        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
         return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                           reinterpret_cast<const BDataType*>(p_b),
@@ -334,8 +313,6 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          static_cast<uint32_t>(num_cu),
-                                          static_cast<uint32_t>(occupancy),
                                           static_cast<uint32_t>(NumSKBlocks));
     }
...
@@ -1007,16 +1007,11 @@ struct BlockToCTileMap_GemmStreamK
     uint32_t k_iters_per_big_block;
     MDiv2 n_tiles;
     MDiv k_iters_per_tile;
-    MDiv eqav_tiles_big;    // for reduction
-    MDiv eqav_tiles_little; // for reduction
+    MDiv equiv_tiles_big;    // for reduction
+    MDiv equiv_tiles_little; // for reduction

     // prefer construct on host
-    BlockToCTileMap_GemmStreamK(uint32_t m,
-                                uint32_t n,
-                                uint32_t k,
-                                uint32_t num_cu,
-                                uint32_t occupancy,
-                                uint32_t sk_blocks = 0)
+    BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t sk_blocks = 0)
     {
         // total output tiles
         uint32_t num_tiles =
@@ -1027,6 +1022,7 @@ struct BlockToCTileMap_GemmStreamK
         // default to regular DP GEMM if sk blocks == 0
         sk_num_blocks = sk_blocks;
+        if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
         {
             sk_num_blocks = 0;
@@ -1042,7 +1038,7 @@ struct BlockToCTileMap_GemmStreamK
         else
         {
             // grid size
-            uint32_t grid_size = occupancy * num_cu;
+            uint32_t grid_size = sk_num_blocks;
             // check if there's enough work for DP+ stream-k
             bool bigEnough = num_tiles > grid_size;
             // max of 2 sk tiles per block
@@ -1068,7 +1064,7 @@ struct BlockToCTileMap_GemmStreamK
             k_iters_per_big_block = k_iters_per_sk_block + 1;
             dp_num_blocks = dp_tiles;
-            dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
+            dp_start_block_idx = sk_num_blocks;
         }
         n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
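The effect of the two changed lines in this hunk is that the stream-k grid is now exactly sk_num_blocks and the data-parallel blocks start immediately after it, instead of rounding dp_start_block_idx up to a multiple of num_cu. A hedged sketch of the surrounding big/little split; the helper and struct names are ours, not the source's, and sk_num_blocks is assumed nonzero:

#include <cstdint>

struct SkSplit
{
    uint32_t sk_num_big_blocks;     // blocks that run one extra K-iteration
    uint32_t k_iters_per_big_block;
    uint32_t dp_start_block_idx;    // DP blocks now start right after the sk blocks
};

SkSplit split_stream_k(uint32_t sk_tiles, uint32_t k_iters_per_tile, uint32_t sk_num_blocks)
{
    uint32_t total_iters          = sk_tiles * k_iters_per_tile;
    uint32_t k_iters_per_sk_block = total_iters / sk_num_blocks; // "little" block share
    uint32_t big_blocks           = total_iters % sk_num_blocks; // remainder -> "big" blocks
    // before this commit dp_start_block_idx was rounded up to a multiple of
    // num_cu; with the grid sized to sk_num_blocks the DP part starts directly
    return SkSplit{big_blocks, k_iters_per_sk_block + 1, sk_num_blocks};
}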
@@ -1079,8 +1075,8 @@ struct BlockToCTileMap_GemmStreamK
         {
             uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
             uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
-            eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
-            eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
+            equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
+            equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
         }
 #if 0
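equiv_tiles_big / equiv_tiles_little (renamed from the misspelled eqav_*) hold the period, in output tiles, after which tile boundaries and sk-block boundaries realign; that is what the lcm above computes. A minimal standalone illustration, assuming C++17 for std::lcm:

#include <cstdint>
#include <numeric>

// period after which tile boundaries and sk-block boundaries realign:
// e.g. k_iters_per_tile = 8 and k_iters_per_big_block = 5 give lcm = 40,
// so the layout of partial tiles repeats every 40 / 8 = 5 tiles
uint32_t equiv_tiles(uint32_t k_iters_per_block, uint32_t k_iters_per_tile)
{
    return std::lcm(k_iters_per_block, k_iters_per_tile) / k_iters_per_tile;
}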
@@ -1091,8 +1087,7 @@ struct BlockToCTileMap_GemmStreamK
                "sk_tiles:%u, workspace(acc float):%u\n",
                num_cu,
                occupancy,
-               // get_grid_dims(num_cu, occupancy).x,
-               get_grid_dims().x,
+               get_grid_dims(num_cu, occupancy).x,
                num_tiles,
                dp_tiles,
                sk_num_big_blocks,
@@ -1124,15 +1119,9 @@ struct BlockToCTileMap_GemmStreamK
     }

     // __host__ __device__ constexpr dim3 get_grid_dims(int num_cu, int occupancy) const
-    __host__ __device__ constexpr dim3 get_grid_dims() const
+    __host__ __device__ constexpr dim3 get_grid_dims(uint32_t num_cu, uint32_t occupancy) const
     {
-        if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
-        {
-            return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
-        }
-        else
-            return dim3(reduction_start_block_idx, 1, 1);
-        // return dim3(num_cu * occupancy, 1, 1); // HS
+        return dim3(num_cu * occupancy, 1, 1);
     }

     __host__ __device__ uint32_t total_blocks_allocated() const
     {
@@ -1240,13 +1229,13 @@ struct BlockToCTileMap_GemmStreamK
     }

     __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
-                                                        const MDiv& eqav_tiles_) const
+                                                        const MDiv& equiv_tiles_) const
     {
         uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
-        uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
+        uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
         uint32_t quo_, rem_;
-        eqav_tiles_.divmod(tile_idx_, quo_, rem_);
-        return quo_ * max_eqav_tiles_ + rem_;
+        equiv_tiles_.divmod(tile_idx_, quo_, rem_);
+        return quo_ * max_equiv_tiles_ + rem_;
     }

     __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
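get_tile_intersections appears to count the sk-block boundaries that fall strictly inside a run of tiles: each full period of equiv_tiles tiles contributes equiv_tiles - 1 interior crossings. A plain-division version for intuition; the source's MDiv performs the same divmod via magic-number division:

#include <cstdint>

uint32_t tile_intersections(uint32_t tiles, uint32_t equiv_tiles)
{
    uint32_t tile_idx = (tiles == 0) ? 0 : tiles - 1;
    uint32_t quo = tile_idx / equiv_tiles; // whole realignment periods
    uint32_t rem = tile_idx % equiv_tiles; // crossings in the partial period
    return quo * (equiv_tiles - 1) + rem;
}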
@@ -1264,9 +1253,9 @@ struct BlockToCTileMap_GemmStreamK
             get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);

         uint32_t total_intersec_big =
-            get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
+            get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
         uint32_t total_intersec_little =
-            get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
+            get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);

         return sk_num_blocks + total_intersec_big + total_intersec_little;
     }
@@ -1281,7 +1270,7 @@ struct BlockToCTileMap_GemmStreamK
             uint32_t touched_sk_blocks =
                 (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
                 k_iters_per_big_block;
-            uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
+            uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
             return touched_sk_blocks + current_intersec;
         }
         else
@@ -1292,7 +1281,7 @@ struct BlockToCTileMap_GemmStreamK
                 (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
                 iters_per_little_sk_block;
             uint32_t current_intersec =
-                get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
+                get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
             return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
         }
     }
@@ -1305,7 +1294,7 @@ struct BlockToCTileMap_GemmStreamK
         {
             uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
                                                           k_iters_per_tile.get() - 1);
-            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
             return block_idx_ + current_intersec;
         }
         else
@@ -1313,7 +1302,7 @@ struct BlockToCTileMap_GemmStreamK
             uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
             uint32_t touched_tiles = k_iters_per_tile.div(
                 block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
-            uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
+            uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
             return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
         }
     }
...
@@ -145,7 +145,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         index_t StrideA;
         index_t StrideB;
         index_t StrideC;
-        index_t num_cu, occupancy; // stream-k arguments
         Block2CTileMap block_mapping;

         Argument(const FloatAB* p_a_grid_,
@@ -157,8 +156,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
-                 uint32_t num_cu_,
-                 uint32_t occupancy_,
                  uint32_t num_sk_blocks_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
@@ -169,9 +166,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
               StrideA(StrideA_),
               StrideB(StrideB_),
               StrideC(StrideC_),
-              num_cu(num_cu_),
-              occupancy(occupancy_),
-              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
+              block_mapping(M, N, K, num_sk_blocks_)
         {
         }
@@ -523,9 +518,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

-        uint32_t* p_semaphore =
-            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
-                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
         // offset for last acc buffer of this block
         uint32_t block_acc_offset =
@@ -536,6 +528,33 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         uint32_t total_iter_length;

+        constexpr auto cluster_length_reduce =
+            GetClusterLengthReduction(); // get nperblock, mperblock for reduction
+        constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
+        const auto reduce_thread_cluster_idx =
+            reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+        const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
+
+        // total M-reduce iterations for the block
+        constexpr auto MReduceIters = math::integer_divide_ceil(
+            Number<MPerBlock>{}, cluster_length_reduce.At(I0));
+        // total N-reduce iterations for the block
+        constexpr auto NReduceIters = math::integer_divide_ceil(
+            Number<NPerBlock>{},
+            cluster_length_reduce.At(I1) * Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
+
+        // thread buf LOAD descriptor
+        constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+        // thread buf STORE descriptor
+        constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+
 #pragma unroll
         // stream-k: for new work for all the persistent blocks.
         for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
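The descriptors added above are constexpr or depend only on the thread id, so hoisting them ahead of the persistent loop (they were previously rebuilt inside the reduction branch on every trip, see the deletion below) costs nothing per work item. The loop itself has the usual persistent-block shape; a shape sketch only, not the real kernel:

#include <hip/hip_runtime.h>
#include <cstdint>

// schematic persistent-block loop: every launched block strides through all
// allocated work items, so loop-invariant setup is computed once up front
__global__ void persistent_kernel(uint32_t total_work_items)
{
    // constexpr / thread-id-only descriptors would be built here, once
    for(uint32_t work = blockIdx.x; work < total_work_items; work += gridDim.x)
    {
        // claim and process one tile; a reduction work item must `continue`
        // (not `return`) so the block keeps draining the remaining work
    }
}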
@@ -558,28 +577,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
             {
                 is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
                 if(is_reduction_block)
                 {
-                    // descriptors
-                    constexpr auto cluster_length_reduce = GetClusterLengthReduction();
-                    constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
-                    const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
-                        make_multi_index(get_thread_local_1d_id()));
-                    const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
-                    const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
-                    constexpr auto MReduceIters = math::integer_divide_ceil(
-                        Number<MPerBlock>{}, cluster_length_reduce.At(I0));
-                    constexpr auto NReduceIters = math::integer_divide_ceil(
-                        Number<NPerBlock>{},
-                        cluster_length_reduce.At(I1) *
-                            Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{});
-                    constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
-                        make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
-                    constexpr auto acc_thread_buf_store_desc =
-                        make_naive_tensor_descriptor_packed(make_tuple(
-                            I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
-
                     constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
@@ -622,8 +622,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
                     auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);

-                    workgroup_barrier wg_barrier(p_semaphore);
-
                     uint32_t tile_acc_offset_start =
                         block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
                     uint32_t tile_acc_offset_end =
@@ -669,9 +667,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                             CBlockTransferScalarPerVector_NWaveNPerXDL),
                         CElementwiseOperation{}};

-                    // block synchronization
-                    wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
-
 #if 0
                     if(threadIdx.x == 0) {
                         printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(block_idx),
@@ -750,9 +745,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                 partial_acc_store_step_m);
                         }
                     }
-                    return;
+                    continue;
                 }
             }

             while(true)
             {
                 uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -1131,17 +1127,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                         mxdlperwave_forward_step);
                 }
             });
-
-            if constexpr(Block2CTileMap::ReductionStrategy ==
-                         StreamKReductionStrategy::Reduction)
-            {
-                if(is_sk_block)
-                {
-                    // increase the counter for this tile
-                    workgroup_barrier wg_barrier(p_semaphore);
-                    wg_barrier.inc(tile_idx);
-                }
-            }
         }

         // exit condition
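Alongside the return-to-continue fix above, the commit drops the semaphore handshake between sk blocks and reduction blocks (p_semaphore, wg_barrier.wait_eq, wg_barrier.inc); how reduction blocks now observe completed partials is not visible in these hunks. For reference, the deleted logic followed a counting-semaphore pattern; a minimal sketch under that assumption, not the actual workgroup_barrier implementation:

#include <hip/hip_runtime.h>
#include <cstdint>

// producer side: an sk block marks one more partial result ready for a tile
__device__ void signal_partial(uint32_t* sem, uint32_t tile_idx)
{
    if(threadIdx.x == 0)
        atomicAdd(&sem[tile_idx], 1u); // was wg_barrier.inc(tile_idx)
}

// consumer side: a reduction block spins until all expected partials arrived
__device__ void wait_partials(uint32_t* sem, uint32_t reduction_idx, uint32_t expected)
{
    if(threadIdx.x == 0)
        while(atomicAdd(&sem[reduction_idx], 0u) < expected) {} // was wg_barrier.wait_eq
    __syncthreads();
}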
...