Commit 07bac859 authored by Jing Zhang's avatar Jing Zhang Committed by root
Browse files

add dync wg id

parent 4feebedd
...@@ -32,13 +32,25 @@ __global__ void ...@@ -32,13 +32,25 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count) const index_t group_count,
const index_t num_wg,
index_t* block_id_count)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
const index_t block_id = get_block_1d_id(); __shared__ index_t block_id_share;
if(get_thread_local_1d_id() == 0)
{
block_id_share = atomic_add(block_id_count, 1);
}
block_sync_lds();
const index_t block_id = block_id_share;
const auto gemm_desc_ptr = const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const)); reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
...@@ -63,7 +75,13 @@ __global__ void ...@@ -63,7 +75,13 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_, gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared), static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_); gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id);
if(get_thread_local_1d_id() == 0 && block_id == num_wg - 1)
{
*block_id_count = 0;
}
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -408,6 +426,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -408,6 +426,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
index_t* block_id_count;
hip_check_error(hipMalloc(&block_id_count, sizeof(index_t)));
float ave_time = 0; float ave_time = 0;
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
...@@ -421,6 +442,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -421,6 +442,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
} }
hip_check_error(hipMemset(block_id_count, 0, sizeof(index_t)));
ave_time = ave_time =
launch_and_time_kernel(stream_config, launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -428,7 +451,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -428,7 +451,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size()); arg.gemm_kernel_args_.size(),
arg.grid_size_,
block_id_count);
}; };
if(all_have_main_k0_block_loop) if(all_have_main_k0_block_loop)
......
...@@ -498,7 +498,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -498,7 +498,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
typename Block2CTileMap> typename Block2CTileMap>
__device__ static void Run(const Argument& karg, __device__ static void Run(const Argument& karg,
void* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map,
const index_t block_id)
{ {
const FloatAB* p_a_grid = karg.p_a_grid; const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid; const FloatAB* p_b_grid = karg.p_b_grid;
...@@ -525,7 +526,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -525,7 +526,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// divide block work by [KBatch, M, N] // divide block work by [KBatch, M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); // block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
if(!block_2_ctile_map.ValidCTileIndex( if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx, block_work_idx,
...@@ -724,53 +726,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -724,53 +726,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
#if 0
// preload data into LDS
{
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
}
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainKBlockLoop)
{
index_t k0_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step);
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf);
block_sync_lds();
b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf);
k0_block_data_begin += K0PerBlock;
} while(k0_block_data_begin < (K0 - K0PerBlock));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// gridwise GEMM pipeline // gridwise GEMM pipeline
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>(); GridwiseGemmPipeline_Selector<PipelineVersion::v2, 1, LoopScheduler::Default>();
...@@ -794,7 +749,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -794,7 +749,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
#endif
// output: register to global memory // output: register to global memory
{ {
......
...@@ -12,8 +12,10 @@ cmake ...@@ -12,8 +12,10 @@ cmake
-save-temps=$PWD" \ -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908;gfx90a;gfx940" \ -D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment