Commit 8f571c0b authored by Harisankar Sadasivan

Files modified for GEMM with atomics and reduction

parent 5490b99c
@@ -78,8 +78,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
         BlockToCTileMap_GemmStreamK<MPerBlock,
                                     NPerBlock,
                                     K0PerBlock * K1,
-                                    // StreamKReductionStrategy::Atomic>,
-                                    StreamKReductionStrategy::Reduction>,
+                                    StreamKReductionStrategy::Atomic>,
+                                    // StreamKReductionStrategy::Reduction>,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
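The hunk above flips the default partial-tile combination strategy from Reduction to Atomic. As a rough illustration of the difference (a sketch with hypothetical names like `c`, `acc_workspace`, `slot`, and `idx`, not CK's actual epilogue code):

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

__device__ void store_partial_atomic(float* c, uint32_t idx, float partial)
{
    // Atomic: every stream-k block adds its partial result straight into C.
    atomicAdd(&c[idx], partial);
}

__device__ void store_partial_reduction(float* acc_workspace,
                                        uint32_t slot,
                                        uint32_t idx,
                                        float partial)
{
    // Reduction: partials are staged in a per-block workspace and summed
    // later by dedicated reduction blocks (plus a semaphore handshake).
    acc_workspace[slot + idx] = partial;
}
```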
@@ -149,10 +149,10 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
             hipDevice_t dev;
             hip_check_error(hipGetDevice(&dev));
             hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-            num_cu         = dev_prop.multiProcessorCount;
-            dim3 grid_dims = karg.block_mapping.sk_num_blocks;
-            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
-                   num_cu * occupancy);
+            num_cu = dev_prop.multiProcessorCount;
+            dim3 grid_dims =
+                (karg.block_mapping.sk_num_blocks ? karg.block_mapping.sk_num_blocks
+                                                  : karg.block_mapping.reduction_start_block_idx);
             float ave_time = 0;
             // TODO: remove clear buffer for streamk kernels
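The grid is now sized from the block mapping: the stream-k block count if one is set, otherwise the reduction start index, which spans the whole launch. The surrounding occupancy query follows the standard HIP pattern; a self-contained sketch, with `my_kernel` and `BLOCK_SIZE` as placeholders for the real stream-k kernel and its block size:

```cpp
#include <hip/hip_runtime.h>

__global__ void my_kernel() {}

int one_wave_grid_size()
{
    constexpr int BLOCK_SIZE = 256;
    int occupancy            = 0;
    int dev                  = 0;
    hipDeviceProp_t prop;
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&prop, dev);
    // max resident workgroups per CU for this kernel at BLOCK_SIZE threads
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &occupancy, my_kernel, BLOCK_SIZE, 0 /* dynamic LDS bytes */);
    // one full wave = #CUs * workgroups-per-CU
    return prop.multiProcessorCount * occupancy;
}
```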
@@ -187,11 +187,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                     karg.block_mapping.get_workspace_size_for_acc(
                         sizeof(typename GridwiseGemm::FloatAcc));
                 auto preprocess = [&]() {
-                    hipGetErrorString(
-                        hipMemsetAsync(workspace_semaphore,
-                                       0,
-                                       karg.block_mapping.get_workspace_size_for_semaphore(),
-                                       stream_config.stream_id_));
+                    hipGetErrorString(hipMemsetAsync(
+                        workspace_semaphore, 0, sizeof(num_cu), stream_config.stream_id_));
                 };
                 ave_time = launch_and_time_kernel_with_preprocess(stream_config,
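With a single counting semaphore, only 4 bytes (`sizeof(num_cu)`) need clearing before each timed launch instead of the whole semaphore workspace. A minimal sketch of the idea, with `d_semaphore` standing in for the pointer into the kernel workspace:

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

// Zero a single 4-byte semaphore on the stream before a timed launch.
inline void clear_semaphore(uint32_t* d_semaphore, hipStream_t stream)
{
    // sizeof(num_cu) in the hunk above is sizeof(int) == 4 bytes: one counter
    (void)hipMemsetAsync(d_semaphore, 0, sizeof(uint32_t), stream);
}
```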
@@ -282,8 +279,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              CElementwiseOperation,
                              uint32_t NumSKBlocks = 0)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, NumSKBlocks};
+        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+        int occupancy, num_cu;
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+        hipDeviceProp_t dev_prop;
+        hipDevice_t dev;
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        static_cast<uint32_t>(num_cu),
+                        static_cast<uint32_t>(occupancy),
+                        NumSKBlocks};
     }

     static auto MakeInvoker() { return Invoker{}; }
@@ -303,7 +319,15 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                              CElementwiseOperation,
                              index_t NumSKBlocks = 0) override
     {
+        const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
+        int occupancy, num_cu;
+        hip_check_error(
+            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
+        hipDeviceProp_t dev_prop;
+        hipDevice_t dev;
+        hip_check_error(hipGetDevice(&dev));
+        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+        num_cu = dev_prop.multiProcessorCount;
         return std::make_unique<Argument>(reinterpret_cast<const ADataType*>(p_a),
                                           reinterpret_cast<const BDataType*>(p_b),
                                           reinterpret_cast<CDataType*>(p_c),
@@ -313,6 +337,8 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
                                           StrideA,
                                           StrideB,
                                           StrideC,
+                                          static_cast<uint32_t>(num_cu),
+                                          static_cast<uint32_t>(occupancy),
                                           static_cast<uint32_t>(NumSKBlocks));
     }
@@ -1000,7 +1000,7 @@ struct BlockToCTileMap_GemmStreamK
     //--------------------------------------
     // pass to device
-    uint32_t sk_num_blocks;
+    mutable uint32_t sk_num_blocks;
     uint32_t sk_num_big_blocks;
     uint32_t dp_start_block_idx;
     uint32_t reduction_start_block_idx;
@@ -1011,7 +1011,12 @@ struct BlockToCTileMap_GemmStreamK
     MDiv equiv_tiles_little; // for reduction

     // prefer construct on host
-    BlockToCTileMap_GemmStreamK(uint32_t m, uint32_t n, uint32_t k, uint32_t sk_blocks = 0)
+    BlockToCTileMap_GemmStreamK(uint32_t m,
+                                uint32_t n,
+                                uint32_t k,
+                                uint32_t num_cu,
+                                uint32_t occupancy,
+                                uint32_t sk_blocks = 0)
     {
         // total output tiles
         uint32_t num_tiles =
@@ -1019,10 +1024,24 @@ struct BlockToCTileMap_GemmStreamK
         k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
         uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
+        const uint32_t one_wave = num_cu * occupancy;
-        // default to regular DP GEMM if sk blocks == 0
-        sk_num_blocks = sk_blocks;
+        if((sk_blocks > one_wave) && (num_tiles > one_wave))
+        {
+            printf("WARNING: Do not tune above max possible occupancy for the kernel, "
+                   "defaulting to max occupancy\n");
+            sk_num_blocks = one_wave;
+        }
+        else if(sk_blocks < one_wave)
+        {
+            printf("Recommended #stream-k blocks (assuming full GPU availability): %0d\n",
+                   one_wave);
+            sk_num_blocks = sk_blocks;
+        }
+        else
+            sk_num_blocks = sk_blocks;
+
+        // default to regular DP GEMM if sk blocks == 0
         if(sk_num_blocks == 0 || sk_num_blocks == 0xFFFFFFFF)
         {
             sk_num_blocks = 0;
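The constructor now caps a user-tuned `sk_blocks` at one full wave (`num_cu * occupancy`) and prints the recommended count when tuning below it. A condensed host-side restatement of that selection rule (a sketch, not the exact CK control flow):

```cpp
#include <cstdint>
#include <cstdio>

uint32_t pick_sk_num_blocks(uint32_t sk_blocks, uint32_t num_tiles,
                            uint32_t num_cu, uint32_t occupancy)
{
    const uint32_t one_wave = num_cu * occupancy;
    if(sk_blocks > one_wave && num_tiles > one_wave)
    {
        std::printf("WARNING: clamping stream-k blocks to one wave (%u)\n", one_wave);
        return one_wave; // never launch more persistent blocks than can be resident
    }
    if(sk_blocks < one_wave)
        std::printf("Recommended #stream-k blocks: %u\n", one_wave);
    return sk_blocks; // 0 or 0xFFFFFFFF still falls through to plain DP GEMM
}
```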
@@ -1064,7 +1083,7 @@ struct BlockToCTileMap_GemmStreamK
             k_iters_per_big_block = k_iters_per_sk_block + 1;
             dp_num_blocks      = dp_tiles;
-            dp_start_block_idx = sk_num_blocks;
+            dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
         }
         n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
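`dp_start_block_idx` is now rounded up to the next multiple of `grid_size`, so data-parallel work begins on a whole pass of the persistent grid. The expression is the usual round-up-to-multiple idiom:

```cpp
#include <cstdint>

// Assumes multiple > 0; CK inlines the expression rather than using a helper.
constexpr uint32_t round_up(uint32_t x, uint32_t multiple)
{
    return ((x + multiple - 1) / multiple) * multiple;
}
// e.g. round_up(sk_num_blocks, grid_size) aligns the first DP block with
// the next full pass of the persistent grid.
```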
@@ -1079,15 +1098,15 @@ struct BlockToCTileMap_GemmStreamK
             equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
         }

-#if 0
-        printf("cu:%d, occupancy:%d, gridsize:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
+#if 1
+        printf("num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
                "sk_total_iters:%d, dp_start_block_idx:%d, dp_num_blocks:%d, "
                "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
                "sk_tiles:%u, workspace(acc float):%u\n",
-               num_cu,
-               occupancy,
-               get_grid_dims(num_cu, occupancy).x,
+               // num_cu,
+               // occupancy,
+               // get_grid_dims(num_cu, occupancy).x,
                num_tiles,
                dp_tiles,
                sk_num_big_blocks,
@@ -23,19 +23,19 @@ namespace ck {
 template <typename GridwiseGemm>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+        __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
-                               const typename GridwiseGemm::FloatAB* p_b_grid,
-                               typename GridwiseGemm::FloatC* p_c_grid,
-                               void* p_workspace,
-                               index_t M,
-                               index_t N,
-                               index_t K,
-                               index_t StrideA,
-                               index_t StrideB,
-                               index_t StrideC,
-                               typename GridwiseGemm::Block2CTileMap block_mapping)
+        kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
+                                   const typename GridwiseGemm::FloatAB* p_b_grid,
+                                   typename GridwiseGemm::FloatC* p_c_grid,
+                                   void* p_workspace,
+                                   index_t M,
+                                   index_t N,
+                                   index_t K,
+                                   index_t StrideA,
+                                   index_t StrideB,
+                                   index_t StrideC,
+                                   typename GridwiseGemm::Block2CTileMap block_mapping)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -145,6 +145,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         index_t StrideA;
         index_t StrideB;
         index_t StrideC;
+        index_t num_cu, occupancy; // stream-k arguments
         Block2CTileMap block_mapping;

         Argument(const FloatAB* p_a_grid_,
@@ -156,6 +157,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
+                 uint32_t num_cu_,
+                 uint32_t occupancy_,
                  uint32_t num_sk_blocks_)
             : p_a_grid(p_a_grid_),
               p_b_grid(p_b_grid_),
@@ -166,7 +169,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
               StrideA(StrideA_),
               StrideB(StrideB_),
               StrideC(StrideC_),
-              block_mapping(M, N, K, num_sk_blocks_)
+              num_cu(num_cu_),
+              occupancy(occupancy_),
+              block_mapping(M, N, K, num_cu_, occupancy_, num_sk_blocks_)
         {
         }
@@ -518,11 +523,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3();

         uint32_t* p_semaphore =
             reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                         block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
-        // offset for last acc buffer of this block
-        uint32_t block_acc_offset =
-            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
-            NPerBlock;
         uint32_t iter_start, iter_end;
         bool is_sk_block, is_dp_block, is_padding_block, is_reduction_block;
@@ -555,11 +559,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                 I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{})); // thread
                                                                                     // buf STORE
                                                                                     // descriptor
 #pragma unroll
+        // stream-k: for new work for all the persistent blocks.
+        for(; block_idx < block_mapping.total_blocks_allocated(); block_idx += gridDim.x)
+        {
+            // offset for last acc buffer of this block
+            uint32_t block_acc_offset =
+                (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
+                NPerBlock;
             is_sk_block = block_idx < block_mapping.sk_num_blocks;
             is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                           block_idx < block_mapping.reduction_start_block_idx;
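The kernel body is now wrapped in a grid-stride loop, so a fixed, occupancy-sized grid of persistent workgroups claims virtual block ids until all allocated blocks are processed. A minimal standalone sketch of the pattern (`total_blocks` plays the role of `block_mapping.total_blocks_allocated()`):

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

__global__ void persistent_kernel(uint32_t total_blocks)
{
    // each launched workgroup keeps claiming "virtual" block ids
    for(uint32_t block_idx = blockIdx.x; block_idx < total_blocks;
        block_idx += gridDim.x)
    {
        // ... do the work for virtual block `block_idx` (stream-k tile,
        // DP tile, or reduction, depending on the id ranges) ...
    }
}
```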
@@ -621,6 +629,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
             // start to compute
             auto reduction_idx = block_idx - block_mapping.reduction_start_block_idx;
             auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);
+            workgroup_barrier wg_barrier(p_semaphore);
             uint32_t tile_acc_offset_start =
                 block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
@@ -666,6 +675,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                                      thread_n_cluster_id *
                                          CBlockTransferScalarPerVector_NWaveNPerXDL),
                 CElementwiseOperation{}};
+            // block synchronization
+            wg_barrier.wait_eq(0, block_mapping.sk_num_blocks);
 #if 0
             if(threadIdx.x == 0) {
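`wait_eq(0, sk_num_blocks)` makes a reduction block spin until every stream-k block has signaled the semaphore. CK's `workgroup_barrier` is the real implementation; the following is only an illustration of a wait-until-equal spin:

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

// One thread polls the global counter; the rest wait at the barrier.
__device__ void wait_eq(uint32_t* sem, uint32_t expected)
{
    if(threadIdx.x == 0)
    {
        // atomicAdd with 0 acts as an atomic load of the counter
        while(atomicAdd(sem, 0u) != expected) {}
    }
    __syncthreads(); // release the remaining threads once the count matches
}
```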
@@ -1142,6 +1153,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
                     // make sure next loop LDS is ready for use
                     block_sync_lds();
                 }
+
+                if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
+                {
+                    if(is_sk_block)
+                    {
+                        // increase the counter for this tile
+                        workgroup_barrier wg_barrier(p_semaphore);
+                        wg_barrier.inc(0);
+                        // printf("block_idx=%0d, \n",block_idx);
+                    }
+                }
             }
         }
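This is the producer side of the handshake: after its partial tiles are stored, each stream-k block increments the semaphore that the reduction blocks' `wait_eq` above counts. An illustrative companion to the earlier `wait_eq` sketch (again, not the `workgroup_barrier` source):

```cpp
#include <cstdint>
#include <hip/hip_runtime.h>

// Bump the counter once per workgroup after the partial tile is fully written.
__device__ void signal_done(uint32_t* sem)
{
    __syncthreads();     // all threads have issued their partial-tile stores
    if(threadIdx.x == 0)
    {
        __threadfence(); // make those stores visible device-wide first
        atomicAdd(sem, 1u);
    }
}
```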