Commit 6c5111b7 authored by Harisankar Sadasivan

Code changes for Stream-K with reduction

parent ab3885aa
......@@ -147,8 +147,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
}
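// Note: with the Atomic strategy, stream-k blocks accumulate partial tiles
// directly into C via AtomicAdd, so C must be zero-initialized up front.
// The Reduction strategy instead writes partials to a workspace and composes
// them in a separate reduction pass, so no memset of C is needed here.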
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
......@@ -195,9 +199,30 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
else
{
if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
else if constexpr(GridwiseGemm::Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore =
reinterpret_cast<char*>(arg.p_workspace_) +
arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
sizeof(GemmAccDataType));
auto preprocess = [&]() {
hipGetErrorString(hipMemsetAsync(
workspace_semaphore,
0,
arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
stream_config.stream_id_));
};
ave_time = launch_and_time_kernel_with_preprocess(
stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
}
};
......
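For context on the host-side pointer arithmetic above: the reduction strategy assumes a single device workspace laid out as the partial-accumulator region followed by the semaphore region, which is why workspace_semaphore is formed by offsetting p_workspace_ by get_workspace_size_for_acc. A minimal standalone sketch of that layout arithmetic, with hypothetical parameters (num_acc_tiles, num_reduction_tiles) standing in for values the tile map computes:

#include <cstddef>
#include <cstdint>

// Sketch only: workspace = [ partial acc tiles | one uint32_t semaphore per C tile ]
inline std::size_t acc_region_bytes(std::size_t num_acc_tiles,
                                    std::size_t m_per_block,
                                    std::size_t n_per_block,
                                    std::size_t acc_elem_bytes)
{
    return num_acc_tiles * m_per_block * n_per_block * acc_elem_bytes;
}

inline std::size_t semaphore_region_bytes(std::size_t num_reduction_tiles)
{
    return num_reduction_tiles * sizeof(std::uint32_t);
}

// The semaphore region starts immediately after the accumulator region,
// mirroring p_workspace_ + get_workspace_size_for_acc(sizeof(GemmAccDataType)).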
......@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -30,15 +32,21 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared,
karg,
karg.p_workspace_,
karg.block_2_ctile_map_streamk);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -51,9 +59,9 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
......@@ -468,20 +476,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << ", Stream-K Selection:" << Streamk_sel
<< ", Grid size:" << Grid_size << "}" << std::endl;
}
......@@ -521,7 +520,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, Streamk_sel_, Grid_size_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
block_2_ctile_map_streamk(
M_, N_, AK0Number * CalculateKPadded(K_, 1), Grid_size_, Streamk_sel_)
{
}
......@@ -529,6 +530,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Reduction,
8,
4>
block_2_ctile_map_streamk;
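// NOTE: the template arguments above are expected to match the
// Block2CTileMap_streamk alias defined later in this class (which this change
// switches from StreamKReductionStrategy::Atomic to ::Reduction). The alias is
// not visible at this point in the class body, so the two are assumed to be
// kept in sync manually.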
};
struct SplitKBatchOffset
......@@ -853,6 +861,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MXdlPerWave / CShuffleMXdlPerWavePerShuffle>{},
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
Number<NXdlPerWave / CShuffleNXdlPerWavePerShuffle>{},
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
......@@ -1118,12 +1139,41 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
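// Worked example with illustrative tuning values (not fixed by this file):
// BlockSize = 256, NPerBlock = 128, CShuffleBlockTransferScalarPerVector_NPerBlock = 8
//   NPerBlockPow2      = 128
//   NPerBlockReduction = 128 / 8 = 16 threads along N (each owning an 8-wide vector)
//   MPerBlockReduction = (256 + 16 - 1) / 16 = 16 threads along M
// -> Sequence<16, 16>: a 16 x 16 thread cluster covering the reduction tile.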
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
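// For a RowMajor C, element (m, n) of a partial tile lives at offset
// m * NPerBlock + n; for ColumnMajor, at m + n * MPerBlock. Each partial tile
// is therefore one dense MPerBlock x NPerBlock block in the output layout.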
using Block2CTileMap_streamk =
BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Reduction,
8,
4>;
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -1132,21 +1182,30 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
Problem& problem,
void* p_workspace,
Block2CTileMap_streamk block_2_ctile_map_streamk)
{
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
// problem.N,
// AK0Number * problem.KPadded,
// problem.Grid_size,
// problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block, is_reduction_block;
index_t num_k_block_main_loop;
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
......@@ -1159,10 +1218,224 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
static_cast<uint32_t>(block_idx) <
block_2_ctile_map_streamk.reduction_start_block_idx;
is_reduction_block = static_cast<uint32_t>(block_idx) >=
block_2_ctile_map_streamk.reduction_start_block_idx;
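// Grid layout assumed by the checks above:
//   [ stream-k blocks | data-parallel blocks | reduction blocks ]
// delimited by dp_start_block_idx and reduction_start_block_idx.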
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
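// get_block_itr hands this workgroup the half-open range [iter_start, iter_end)
// of K iterations; a stream-k block may own only a fraction of a macro tile's
// K loop, while a data-parallel block owns whole tiles.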
uint32_t* p_semaphore = reinterpret_cast<uint32_t*>(
reinterpret_cast<char*>(p_workspace) +
block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType)));
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
const auto reduce_thread_cluster_idx = reduce_desc.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id())); // per-thread id; block_idx is
// uniform across the workgroup and would give every lane the same coordinate
const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
constexpr auto MReduceIters = math::integer_divide_ceil(
Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) *
Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{});
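// With the illustrative 16 x 16 cluster above and MPerBlock = NPerBlock = 128:
//   MReduceIters = ceil(128 / 16)       = 8 row sweeps
//   NReduceIters = ceil(128 / (16 * 8)) = 1 column sweep (vector width 8)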
constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto acc_thread_buf_store_desc =
make_naive_tensor_descriptor_packed(make_tuple(
I1, I1, I1, Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n =
make_multi_index(0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n =
make_multi_index(0,
0,
0,
cluster_length_reduce.At(I1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
0,
0,
0,
-1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
CShuffleBlockTransferScalarPerVector_NPerBlock);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
partial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
AccDataType,
CShuffleBlockTransferScalarPerVector_NPerBlock,
true>
acc_buf;
// map this reduction workgroup to its output tile and start reducing
auto reduction_idx =
blockIdx.x - block_2_ctile_map_streamk.reduction_start_block_idx;
auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial(
reduction_idx, problem.M, problem.N);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx +
1);
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
AccDataType, // SrcData,
AccDataType, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_load_desc), // DstDesc,
Sequence<1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1>, // DimAccessOrder,
1, // SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(thread_m_cluster_id,
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
AccDataType, // SrcData,
CDataType, // DstData,
decltype(acc_thread_buf_store_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<1,
1,
1,
CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths,
Sequence<0, 1, 2, 3>, // DimAccessOrder,
3, // DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
thread_m_cluster_id,
__builtin_amdgcn_readfirstlane(spatial_idx[I1]),
thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock),
CElementwiseOperation{}};
// block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
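// Each contributing stream-k block bumps this tile's semaphore after flushing
// its partial to the workspace (see wg_barrier.inc at the end of the sk-block
// path below), so waiting until the counter equals the number of partial
// buffers for this tile guarantees all partials are visible before reading.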
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;
for(int i_m = 0; i_m < MReduceIters; i_m++)
{
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
acc_buf.Clear();
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global,
AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) +
i * c_partial_acc_block_m_n.GetElementSpaceSize(),
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_load_desc,
make_tuple(I0, I0),
partial_acc_buf);
static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_load_desc.CalculateOffset(
make_tuple(0, i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
partial_acc_buf[Number<offset>{}]);
});
}
if(thread_n_cluster_id *
CShuffleBlockTransferScalarPerVector_NPerBlock <
NPerBlock)
{
acc_store.Run(acc_thread_buf_store_desc,
make_tuple(I0, I0, I0, I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
if constexpr(NReduceIters != 1)
{
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
}
});
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
}
return;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) *
MPerBlock * NPerBlock;
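// Example: if this block's partials occupy acc-buffer slots [s, e), its last
// (highest-numbered) macro tile is stored at slot e - 1; the offset is then
// walked backwards by one MPerBlock * NPerBlock tile per finished macro tile
// (see the decrement near the loop's exit condition).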
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -1185,14 +1458,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
problem.NPadded,
problem.StrideB,
problem.BK0);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -1363,11 +1628,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.GetElementSpaceSize());
auto c_partial_acc_buf =
make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
reinterpret_cast<AccDataType*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
......@@ -1477,7 +1751,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_m_id, 0, block_n_id, 0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave *
NPerXdl>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
false, // bool ThreadTransferSrcResetCoordinateAfterRun => must be
// false, otherwise it generates scratch
false> // bool ThreadTransferDstResetCoordinateAfterRun => must be
// false, otherwise it generates scratch
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
......@@ -1535,15 +1836,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else if(is_sk_block)
{
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.template Run<decltype(c_shuffle_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
else if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// reset slice origins before staging this shuffle tile into the partial-acc workspace
c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(0, 0, 0, 0));
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(MXdlPerWave, 0, NXdlPerWave, 0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_shuffle_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
c_shuffle_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
}
if constexpr(access_id < num_access - 1)
......@@ -1555,11 +1881,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
}
});
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// signal completion: bump this tile's semaphore so the reduction block's
// wait_eq can observe that this partial tile has been flushed to the workspace
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
}
// exit condition
iter_end -= current_iter_length;
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap_streamk::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......