"git@developer.sourcefind.cn:modelzoo/gemma-2_pytorch.git" did not exist on "9c692fdf040b276519d1ff158929a85c9af2e0d3"
Commit bc5d7b6a authored by Jing Zhang's avatar Jing Zhang
Browse files

remove zero

parent 5ca6b1f8
@@ -222,11 +222,17 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     auto argument = gemm.MakeArgument(
         p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);

-    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
+    std::size_t grouped_gemm_kernel_args_buf_size =
+        grouped_gemm_kernel_args_.size() * sizeof(GroupedGemmKernelArgument);
+
+    DeviceMem gemm_arg_dev_mem(grouped_gemm_kernel_args_buf_size);
+    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
+
+    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());

-    hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
+    hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
                               grouped_gemm_kernel_args_.data(),
-                              gemm.GetWorkSpaceSize(&argument),
+                              grouped_gemm_kernel_args_buf_size,
                               hipMemcpyHostToDevice));

     if(!gemm.IsSupportedArgument(argument))
@@ -236,11 +242,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             "not support this GEMM problem");
     }

-    gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
     gemm.SetKBatch(argument, config.k_batch);

     invoker.Run(argument, StreamConfig{nullptr, false});

+    if(config.time_kernel)
+    {
+        float ave_time   = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << gemm.GetTypeString() << std::endl;
+    }
+
     bool pass = true;
     if(config.do_verification)
     {
@@ -273,16 +289,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         }
     }

-    if(config.time_kernel)
-    {
-        float ave_time   = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
-        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
-        float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-                  << " GB/s, " << gemm.GetTypeString() << std::endl;
-    }
-
     return pass;
 }
@@ -293,8 +299,10 @@ int main(int argc, char* argv[])
     problem_size.group_count = 16;

-    problem_size.Ms = {
-        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
+    // problem_size.Ms = {
+    //     167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 180, 184, 168, 156, 168, 148};
+    problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};

     for(int i = 0; i < problem_size.group_count; i++)
     {
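In outline, the host-side sequence in the example after this change is sketched below. The identifiers (gemm, argument, invoker, grouped_gemm_kernel_args_) are the ones from the example above; this is a condensed excerpt for orientation, not a standalone program:

    // Per-group kernel arguments now go in their own device buffer ...
    std::size_t grouped_gemm_kernel_args_buf_size =
        grouped_gemm_kernel_args_.size() * sizeof(GroupedGemmKernelArgument);
    DeviceMem gemm_arg_dev_mem(grouped_gemm_kernel_args_buf_size);
    hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
                              grouped_gemm_kernel_args_.data(),
                              grouped_gemm_kernel_args_buf_size,
                              hipMemcpyHostToDevice));
    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());

    // ... while GetWorkSpaceSize()/SetWorkSpacePointer() now describe a separate
    // workspace that the device op uses for its split-K barrier counters.
    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());

    gemm.SetKBatch(argument, config.k_batch);
    invoker.Run(argument, StreamConfig{nullptr, false});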
......
-#pragma once
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
@@ -41,6 +40,8 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                         uint32_t* barrier_count,
+                                         const index_t barrier_size_grp,
                                          const index_t group_count,
                                          const index_t grid_size_grp,
                                          const index_t KBatch,
@@ -96,12 +97,26 @@ __global__ void
     });

     index_t id_off = 0;
+    index_t id_local = get_block_1d_id() - BlockStart;
+
+    const index_t mn_blocks = local_grid_size / KBatch;
+
+    __shared__ index_t k_id_start, k_id_finished;
+
+    ignore = barrier_count;
+    ignore = k_id_start;
+    ignore = k_id_finished;

-    while((get_block_1d_id() - BlockStart + id_off) < local_grid_size)
+    while(id_local < local_grid_size)
     {
         const auto block_2_etile_map =
             GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

+        auto barrier_count_start =
+            barrier_count + group_id * barrier_size_grp * 2 + id_local % mn_blocks;
+        auto barrier_count_finished = barrier_count + group_id * barrier_size_grp * 2 +
+                                      barrier_size_grp + id_local % mn_blocks;
+
         GridwiseGemm::template Run<HasMainKBlockLoop,
                                    EGlobalMemoryDataOperation,
                                    GemmSpec,
@@ -113,6 +128,8 @@ __global__ void
             p_ds_grid_,
             gemm_desc_ptr[group_id].p_e_grid,
             p_shared,
+            barrier_count_start,
+            barrier_count_finished,
             a_element_op,
             b_element_op,
             c_element_op,
@@ -127,6 +144,7 @@ __global__ void
             block_2_etile_map);

         id_off += grid_size_grp;
+        id_local += grid_size_grp;
     }
 #else
     ignore = gemm_descs_const;
@@ -430,7 +448,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};

             grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
-
             grid_size_ = grid_size_grp_ * group_count_;
         }
@@ -568,6 +585,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 group_id++;
             }
+
+            const auto e_grid_desc_sum_m_n =
+                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
+                    sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+
+            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
+
+            barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
         }

         // private:
@@ -585,6 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         index_t grid_size_;
         index_t grid_size_grp_;
+        index_t barrier_size_grp_;
         index_t sum_of_m;
         index_t k_batch_;
@@ -642,6 +668,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             dim3(BlockSize),
             0,
             cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+            reinterpret_cast<uint32_t*>(arg.p_workspace_),
+            arg.barrier_size_grp_,
             arg.gemm_desc_kernel_arg_.size(),
             arg.grid_size_grp_,
             arg.k_batch_,
@@ -808,8 +836,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
-        return dynamic_cast<const Argument*>(p_arg)->group_count_ *
-               sizeof(GroupedGemmKernelArgument<NumDTensor>);
+        auto arg = *dynamic_cast<const Argument*>(p_arg);
+
+        return arg.group_count_ * (arg.barrier_size_grp_ * 2) * sizeof(uint32_t);
+    }
+
+    void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
+    {
+        auto p_arg_          = dynamic_cast<Argument*>(p_arg);
+        p_arg_->p_workspace_ = p_workspace;
+
+        hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
     }

     static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
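The barrier-counter workspace managed by GetWorkSpaceSize()/SetWorkSpacePointer() above is laid out per group as barrier_size_grp_ * 2 uint32 counters: a block of "start" counters followed by a block of "finished" counters, one pair per output (M, N) tile. A small standalone sketch of that layout; the helper names are illustrative only and mirror the indexing used by kernel_grouped_gemm_xdl_fixed_nk (which uses id_local % mn_blocks as the tile index):

    #include <cstddef>
    #include <cstdint>

    // start    counter: barrier_count + group_id * barrier_size_grp * 2 + tile_id
    // finished counter: barrier_count + group_id * barrier_size_grp * 2 + barrier_size_grp + tile_id
    inline std::size_t barrier_start_idx(std::size_t group_id, std::size_t barrier_size_grp, std::size_t tile_id)
    {
        return group_id * barrier_size_grp * 2 + tile_id;
    }

    inline std::size_t barrier_finished_idx(std::size_t group_id, std::size_t barrier_size_grp, std::size_t tile_id)
    {
        return group_id * barrier_size_grp * 2 + barrier_size_grp + tile_id;
    }

    // Matches GetWorkSpaceSize(): every group owns (barrier_size_grp * 2) uint32_t counters.
    inline std::size_t barrier_workspace_bytes(std::size_t group_count, std::size_t barrier_size_grp)
    {
        return group_count * barrier_size_grp * 2 * sizeof(uint32_t);
    }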
......
@@ -475,6 +475,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                DsGridPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,
+                               uint32_t* barrier_count_start,
+                               uint32_t* barrier_count_finished,
+                               const index_t KBatch,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CDEElementwiseOperation_& cde_element_op,
@@ -492,17 +495,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());

-        const auto ds_grid_buf = generate_tuple(
-            [&](auto i) {
-                return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_ds_grid[i],
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
-            },
-            Number<NumDTensor_>{});
-
-        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -661,8 +653,38 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                                           c_thread_buf,
                                                           num_k_block_main_loop);

+        // ignore = barrier_count_start;
+        // ignore = barrier_count_finished;
+        // ignore = KBatch;
+
+        __shared__ index_t k_id_start_shared;
+
+        if(threadIdx.x == 0)
+        {
+            const auto k_id_start_t = atomicAdd(barrier_count_start, 1);
+
+            k_id_start_shared = k_id_start_t;
+
+            if(k_id_start_t > 0)
+            {
+                while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
+            }
+        }
+
+        __syncthreads();
+
         // shuffle C and write out
         {
+            const auto ds_grid_buf = generate_tuple(
+                [&](auto i) {
+                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_ds_grid[i],
+                        ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+                },
+                Number<NumDTensor_>{});
+
+            auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                           "wrong!");
@@ -799,6 +821,34 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 },
                 Number<NumDTensor_>{}));

+            // space filling curve for threadwise C in VGPR before shuffle
+            constexpr auto sfc_c_vgpr =
+                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
+                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+                                  Sequence<CShuffleMXdlPerWavePerShuffle,
+                                           CShuffleNXdlPerWavePerShuffle,
+                                           1,
+                                           1,
+                                           M2,
+                                           1,
+                                           M4,
+                                           1>>{};
+
+            // space filling curve for shuffled blockwise C/D/E
+            constexpr auto sfc_cde_block =
+                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
+                                  Sequence<0, 2, 1, 3>,
+                                  Sequence<1,
+                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                                           1,
+                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
+
+            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
+
+            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+
+            if(k_id_start_shared == 0)
+            {
                 // blockwise copy C/D/E between LDS and global
                 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
                     ThisThreadBlock,
@@ -807,8 +857,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     decltype(c_ds_desc_refs),
                     decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                     CDEElementwiseOperation_,
-                    Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                                                // support arbitray type
+                    Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, // FIXME: make
+                                                                                    // Sequence
+                                                                                    // support
+                                                                                    // arbitray type
                     Sequence<1,
                              CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                              1,
@@ -818,9 +870,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
                     3,                    // index_t VectorDim,
                     CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-                    sequence_merge_t<
-                        Sequence<true>,
-                        uniform_sequence_gen_t<NumDTensor_,
-                                               false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                    sequence_merge_t<Sequence<true>,
+                                     uniform_sequence_gen_t<
+                                         NumDTensor_,
+                                         false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
                     Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
                     {c_ds_desc_refs,
@@ -829,31 +881,78 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                      make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
                      cde_element_op};
-            // space filling curve for threadwise C in VGPR before shuffle
-            constexpr auto sfc_c_vgpr =
-                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
-                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
-                                  Sequence<CShuffleMXdlPerWavePerShuffle,
-                                           CShuffleNXdlPerWavePerShuffle,
-                                           1,
-                                           1,
-                                           M2,
-                                           1,
-                                           M4,
-                                           1>>{};
-
-            // space filling curve for shuffled blockwise C/D/E
-            constexpr auto sfc_cde_block =
-                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
-                                  Sequence<0, 2, 1, 3>,
-                                  Sequence<1,
-                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                                           1,
-                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
-
-            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-
-            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
-
+                static_for<0, num_access, 1>{}([&](auto access_id) {
+                    // make sure it's safe to write to LDS
+                    block_sync_lds();
+
+                    // each thread write its data from VGPR to LDS
+                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                                  c_thread_buf,
+                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c_shuffle_block_buf);
+
+                    // make sure it's safe to read from LDS
+                    block_sync_lds();
+
+                    // each block copy its data from LDS to global
+                    cde_block_copy_lds_and_global.Run(
+                        c_ds_desc_refs,
+                        c_ds_buf_refs,
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        tie(e_grid_buf));
+
+                    if constexpr(access_id < num_access - 1)
+                    {
+                        constexpr auto cde_lds_and_global_step =
+                            sfc_cde_block.GetForwardStep(access_id);
+
+                        // move on Ds
+                        static_for<0, NumDTensor_, 1>{}([&](auto i) {
+                            cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                                c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                        });
+
+                        // move on E
+                        cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                            I0,
+                            cde_lds_and_global_step);
+                    }
+                });
+            }
+            else
+            {
+                // blockwise copy C/D/E between LDS and global
+                auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
+                    ThisThreadBlock,
+                    decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
+                    Tuple<EDataType>,
+                    decltype(c_ds_desc_refs),
+                    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                    CDEElementwiseOperation_,
+                    Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                                                // Sequence support
+                                                                                // arbitray type
+                    Sequence<1,
+                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                             1,
+                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                    Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                    3,                    // index_t VectorDim,
+                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                    sequence_merge_t<Sequence<true>,
+                                     uniform_sequence_gen_t<
+                                         NumDTensor_,
+                                         false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                    Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
+                    {c_ds_desc_refs,
+                     idx_c_ds_block_begin,
+                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                     make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
+                     cde_element_op};
+
                 static_for<0, num_access, 1>{}([&](auto access_id) {
                     // make sure it's safe to write to LDS
@@ -895,6 +994,20 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     }
                 });
             }
+
+            __syncthreads();
+
+            if(threadIdx.x == 0)
+            {
+                index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
+
+                if(k_id_finished_t == KBatch - 1)
+                {
+                    *barrier_count_start    = 0;
+                    *barrier_count_finished = 0;
+                }
+            }
+        }
     }

     template <bool HasMainKBlockLoop,
@@ -910,6 +1023,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                DsGridPointer p_ds_grid,
                                void* __restrict__ p_e_grid_,
                                void* __restrict__ p_shared,
+                               uint32_t* barrier_count_start,
+                               uint32_t* barrier_count_finished,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CDEElementwiseOperation& cde_element_op,
@@ -977,6 +1092,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
+                barrier_count_start,
+                barrier_count_finished,
+                KBatch,
                 a_element_op,
                 b_element_op,
                 cde_element_op,
@@ -994,6 +1112,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
+                barrier_count_start,
+                barrier_count_finished,
+                KBatch,
                 a_element_op,
                 b_element_op,
                 ck::tensor_operation::element_wise::PassThrough{},
......
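For reference, the split-K tile barrier that the gridwise kernel implements above works as follows: the first workgroup to arrive at an output tile (its start counter returns 0) writes its partial result with a plain store, every later workgroup spins until the finished counter becomes non-zero and then accumulates atomically, and the last workgroup to finish resets both counters. A minimal host-side analogue using std::atomic, for illustration only (it is not the device code; the thread count stands in for KBatch workgroups):

    #include <atomic>
    #include <cstdint>
    #include <iostream>
    #include <thread>
    #include <vector>

    int main()
    {
        const int k_batch = 4;               // split-K partial results per tile
        std::atomic<uint32_t> start{0};      // plays the role of barrier_count_start
        std::atomic<uint32_t> finished{0};   // plays the role of barrier_count_finished
        std::atomic<long long> tile{-12345}; // simulated output tile with stale contents

        auto worker = [&](long long partial) {
            const uint32_t k_id = start.fetch_add(1);
            if(k_id == 0)
            {
                // first arrival overwrites the stale tile ("Set")
                tile.store(partial, std::memory_order_relaxed);
            }
            else
            {
                // later arrivals wait until the overwriting worker has finished,
                // then accumulate ("AtomicAdd")
                while(finished.load(std::memory_order_acquire) == 0) {}
                tile.fetch_add(partial, std::memory_order_relaxed);
            }

            const uint32_t done = finished.fetch_add(1, std::memory_order_release);
            if(done == static_cast<uint32_t>(k_batch) - 1)
            {
                // last worker resets the counters for the next use of this barrier
                start.store(0, std::memory_order_relaxed);
                finished.store(0, std::memory_order_relaxed);
            }
        };

        std::vector<std::thread> threads;
        for(int k = 0; k < k_batch; ++k)
            threads.emplace_back(worker, 1LL);
        for(auto& t : threads)
            t.join();

        std::cout << "tile = " << tile.load() << '\n'; // expect k_batch: stale value was overwritten
    }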