Commit 8f571c0b authored by Harisankar Sadasivan
Browse files

files modified for gemm with atomics and reduction

parent 5490b99c
...@@ -78,8 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -78,8 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
BlockToCTileMap_GemmStreamK<MPerBlock, BlockToCTileMap_GemmStreamK<MPerBlock,
NPerBlock, NPerBlock,
K0PerBlock * K1, K0PerBlock * K1,
// StreamKReductionStrategy::Atomic>, StreamKReductionStrategy::Atomic>,
StreamKReductionStrategy::Reduction>, // StreamKReductionStrategy::Reduction>,
ADataType, // TODO: distinguish A/B datatype ADataType, // TODO: distinguish A/B datatype
AccDataType, AccDataType,
CDataType, CDataType,
...@@ -149,10 +149,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -149,10 +149,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
hipDevice_t dev; hipDevice_t dev;
hip_check_error(hipGetDevice(&dev)); hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount; num_cu = dev_prop.multiProcessorCount;
dim3 grid_dims = karg.block_mapping.sk_num_blocks; dim3 grid_dims =
printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n", (karg.block_mapping.sk_num_blocks ? karg.block_mapping.sk_num_blocks
num_cu * occupancy); : karg.block_mapping.reduction_start_block_idx);
float ave_time = 0; float ave_time = 0;
// TODO: remove clear buffer for streamk kernels // TODO: remove clear buffer for streamk kernels
...@@ -187,11 +187,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -187,11 +187,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
karg.block_mapping.get_workspace_size_for_acc( karg.block_mapping.get_workspace_size_for_acc(
sizeof(typename GridwiseGemm::FloatAcc)); sizeof(typename GridwiseGemm::FloatAcc));
auto preprocess = [&]() { auto preprocess = [&]() {
hipGetErrorString( hipGetErrorString(hipMemsetAsync(
hipMemsetAsync(workspace_semaphore, workspace_semaphore, 0, sizeof(num_cu), stream_config.stream_id_));
0,
karg.block_mapping.get_workspace_size_for_semaphore(),
stream_config.stream_id_));
}; };
ave_time = launch_and_time_kernel_with_preprocess(stream_config, ave_time = launch_and_time_kernel_with_preprocess(stream_config,
...@@ -282,8 +279,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -282,8 +279,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
CElementwiseOperation, CElementwiseOperation,
uint32_t NumSKBlocks = 0) uint32_t NumSKBlocks = 0)
{ {
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, NumSKBlocks}; int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
static_cast<uint32_t>(num_cu),
static_cast<uint32_t>(occupancy),
NumSKBlocks};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -303,7 +319,15 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -303,7 +319,15 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
CElementwiseOperation, CElementwiseOperation,
index_t NumSKBlocks = 0) override index_t NumSKBlocks = 0) override
{ {
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a), return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
reinterpret_cast<const BDataType*>(p_b), reinterpret_cast<const BDataType*>(p_b),
reinterpret_cast<CDataType*>(p_c), reinterpret_cast<CDataType*>(p_c),
...@@ -313,6 +337,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -313,6 +337,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
StrideA, StrideA,
StrideB, StrideB,
StrideC, StrideC,
static_cast<uint32_t>(num_cu),
static_cast<uint32_t>(occupancy),
static_cast<uint32_t>(NumSKBlocks)); static_cast<uint32_t>(NumSKBlocks));
} }
......
...@@ -1000,7 +1000,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1000,7 +1000,7 @@ struct BlockToCTileMap_GemmStreamK
//-------------------------------------- //--------------------------------------
// pass to device // pass to device
uint32_t sk_num_blocks; mutable uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks; uint32_t sk_num_big_blocks;
uint32_t dp_start_block_idx; uint32_t dp_start_block_idx;
uint32_t reduction_start_block_idx; uint32_t reduction_start_block_idx;
...@@ -1011,7 +1011,12 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1011,7 +1011,12 @@ struct BlockToCTileMap_GemmStreamK
MDiv equiv_tiles_little; // for reduction MDiv equiv_tiles_little; // for reduction
// prefer construct on host // prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t sk_blocks = 0) BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
uint32_t k,
uint32_t num_cu,
uint32_t occupancy,
uint32_t sk_blocks = 0)
{ {
// total output tiles // total output tiles
uint32_t num_tiles = uint32_t num_tiles =
...@@ -1019,10 +1024,24 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1019,10 +1024,24 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
uint32_t dp_tiles, dp_num_blocks, sk_total_iters; uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
const uint32_t one_wave = num_cu * occupancy;
// default to regular DP GEMM if sk blocks == 0 if((sk_blocks > one_wave) && (num_tiles > one_wave))
sk_num_blocks = sk_blocks; {
printf("WARNING: Do not tune above max possible occupancy for the kernel, "
"defaulting to max occupancy\n ");
sk_num_blocks = one_wave;
}
else if(sk_blocks < one_wave)
{
printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
one_wave);
sk_num_blocks = sk_blocks;
}
else
sk_num_blocks = sk_blocks;
// default to regular DP GEMM if sk blocks == 0
if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF) if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
{ {
sk_num_blocks = 0; sk_num_blocks = 0;
...@@ -1064,7 +1083,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1064,7 +1083,7 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_big_block = k_iters_per_sk_block + 1; k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles; dp_num_blocks = dp_tiles;
dp_start_block_idx = sk_num_blocks; dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
} }
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
...@@ -1079,15 +1098,15 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1079,15 +1098,15 @@ struct BlockToCTileMap_GemmStreamK
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
} }
#if 0 #if 1
printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, " printf("num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, " "sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, " "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, " "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
"sk_tiles:%u, workspace(acc float):%u\n", "sk_tiles:%u, workspace(acc float):%u\n",
num_cu, // num_cu,
occupancy, // occupancy,
get_grid_dims(num_cu, occupancy).x, // get_grid_dims(num_cu, occupancy).x,
num_tiles, num_tiles,
dp_tiles, dp_tiles,
sk_num_big_blocks, sk_num_big_blocks,
......
...@@ -23,19 +23,19 @@ namespace ck { ...@@ -23,19 +23,19 @@ namespace ck {
template <typename GridwiseGemm> template <typename GridwiseGemm>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
const typename GridwiseGemm::FloatAB* p_b_grid, const typename GridwiseGemm::FloatAB* p_b_grid,
typename GridwiseGemm::FloatC* p_c_grid, typename GridwiseGemm::FloatC* p_c_grid,
void* p_workspace, void* p_workspace,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
typename GridwiseGemm::Block2CTileMap block_mapping) typename GridwiseGemm::Block2CTileMap block_mapping)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__)) defined(__gfx94__))
...@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
index_t StrideA; index_t StrideA;
index_t StrideB; index_t StrideB;
index_t StrideC; index_t StrideC;
index_t num_cu, occupancy; // stream-k arguments
Block2CTileMap block_mapping; Block2CTileMap block_mapping;
Argument(const FloatAB* p_a_grid_, Argument(const FloatAB* p_a_grid_,
...@@ -156,6 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -156,6 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
uint32_t num_cu_,
uint32_t occupancy_,
uint32_t num_sk_blocks_) uint32_t num_sk_blocks_)
: p_a_grid(p_a_grid_), : p_a_grid(p_a_grid_),
p_b_grid(p_b_grid_), p_b_grid(p_b_grid_),
...@@ -166,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -166,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
StrideA(StrideA_), StrideA(StrideA_),
StrideB(StrideB_), StrideB(StrideB_),
StrideC(StrideC_), StrideC(StrideC_),
block_mapping(M, N, K, num_sk_blocks_) num_cu(num_cu_),
occupancy(occupancy_),
block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
{ {
} }
...@@ -518,11 +523,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -518,11 +523,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3(); const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
uint32_t iter_start, iter_end; uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block; bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
...@@ -555,11 +559,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -555,11 +559,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); // thread I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); // thread
// buf STORE // buf STORE
// descriptor // descriptor
#pragma unroll
// stream-k: for new work for all the persistent blocks. // stream-k: for new work for all the persistent blocks.
for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x) for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
{ {
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
is_sk_block = block_idx < block_mapping.sk_num_blocks; is_sk_block = block_idx < block_mapping.sk_num_blocks;
is_dp_block = block_idx >= block_mapping.dp_start_block_idx && is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx; block_idx < block_mapping.reduction_start_block_idx;
...@@ -621,6 +629,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -621,6 +629,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// start to compute // start to compute
auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx; auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n); auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start = uint32_t tile_acc_offset_start =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx); block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
...@@ -666,6 +675,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -666,6 +675,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
thread_n_cluster_id * thread_n_cluster_id *
CBlockTransferScalarPerVector_NWaveNPerXDL), CBlockTransferScalarPerVector_NWaveNPerXDL),
CElementwiseOperation{}}; CElementwiseOperation{}};
// block synchronization
wg_barrier.wait_eq(0, block_mapping.sk_num_blocks);
#if 0 #if 0
if(threadIdx.x == 0) { if(threadIdx.x == 0) {
...@@ -1142,6 +1153,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk ...@@ -1142,6 +1153,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// make sure next loop LDS is ready for use // make sure next loop LDS is ready for use
block_sync_lds(); block_sync_lds();
} }
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(0);
// printf("block_idx=%0d, \n",block_idx);
}
}
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment