Commit ca8b5c79 authored by carlushuang's avatar carlushuang
Browse files

update reduction for streamk(not ready yet)

parent b2a49620
......@@ -111,6 +111,16 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
threadwise_transfer_.SetDstSliceOrigin(dst_desc, dst_slice_origin_idx);
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
......
......@@ -141,7 +141,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
// TODO: remove clear buffer for streamk kernels
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
}
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
workspace_semaphore =
workspace_semaphore +
karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
}
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -151,6 +165,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
......@@ -170,6 +185,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc));
}
else
{
return 0;
}
}
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......
......@@ -637,9 +637,16 @@ struct BlockToCTileMap_3DGrid_KSplit
}
};
enum StreamKReductionStrategy
{
Atomic = 0, // sk block use atomic to do reduction
Reduction, // let some workgroup responsible for doing the reduction operation
};
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
......@@ -647,34 +654,36 @@ struct BlockToCTileMap_GemmStreamK
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t sk_total_iters;
// uint32_t sk_total_iters;
uint32_t dp_start_block_idx;
uint32_t dp_iters_per_block;
uint32_t dp_num_blocks;
// uint32_t dp_iters_per_block;
// uint32_t dp_num_blocks;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
// uint32_t tiles_cover_big_blocks; // for reduction
// uint32_t total_acc_buffers; // for reduction
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv n_tiles;
MDiv eqav_tiles_big; // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
static int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
if(v)
r = atoi(v);
return r;
}
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
......@@ -727,8 +736,9 @@ struct BlockToCTileMap_GemmStreamK
sk_tiles = partial_dispatche_tiles + num_cu;
}
dp_iters_per_block = k_iters_per_tile.get();
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
......@@ -775,7 +785,6 @@ struct BlockToCTileMap_GemmStreamK
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
sk_num_blocks = env_get_int("sk_num_blocks", sk_num_blocks);
if(sk_num_blocks == 0)
{
......@@ -807,7 +816,16 @@ struct BlockToCTileMap_GemmStreamK
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
// tile_swizzle_sub_m_rem =
// MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);
......@@ -831,9 +849,28 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_big_block);
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ dim3 get_grid_dims() const
{
return dim3(dp_start_block_idx + dp_num_blocks, 1, 1);
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
}
else
return dim3(reduction_start_block_idx, 1, 1);
}
__device__ uint32_t get_block_idx() const
......@@ -858,6 +895,8 @@ struct BlockToCTileMap_GemmStreamK
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
......@@ -882,10 +921,11 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t /*n*/) const
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
n_tiles.divmod(tile_idx, m_tile_idx, n_tile_idx);
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// return make_tuple(m_tile_idx, n_tile_idx);
// swizzle tile
......@@ -901,7 +941,7 @@ struct BlockToCTileMap_GemmStreamK
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles.get();
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
......@@ -911,6 +951,115 @@ struct BlockToCTileMap_GemmStreamK
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
// __host__ __device__ uint32_t get_workspace_offset_for_semaphore() const
// {
// // workspace contains 2 part, 1) partial reduction buffer 2) semaphore for cross-wg sync
// // we let 1) start from offset:0, 2) start from the end of 1)
// // NOTE: offset is in unit of byte
// return get_total_acc_buffers() *
// }
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
uint32_t quo_, rem_;
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_eqav_tiles_ + rem_;
// return tile_idx_ / eqav_tiles_ * max_eqav_tiles_ + (tile_idx_ % eqav_tiles_);
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
// printf("reverse tile:%u, %u/%u\n", tile_idx_little_reverse, touched_sk_blocks,
// current_intersec);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
// uint32_t touched_tiles = (block_idx_ * iters_per_big_sk_block + iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
// uint32_t touched_tiles = (block_idx_little_reverse * iters_per_little_sk_block +
// iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck
......@@ -14,8 +14,9 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
......@@ -28,6 +29,7 @@ __global__ void
kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
const typename GridwiseGemm::FloatAB* p_b_grid,
typename GridwiseGemm::FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
......@@ -45,6 +47,7 @@ __global__ void
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_workspace,
M,
N,
K,
......@@ -52,18 +55,26 @@ __global__ void
StrideB,
StrideC,
block_mapping,
static_cast<void*>(p_shared));
#else
ignore = karg;
ignore = b2c_map;
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_workspace;
ignore = M;
ignore = N;
ignore = K;
ignore = StrideA;
ignore = StrideB;
ignore = StrideC;
ignore = block_mapping;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename Block2CTileMap_,
typename FloatAB_,
typename FloatAcc,
typename FloatAcc_,
typename FloatC_,
typename ALayout,
typename BLayout,
......@@ -117,6 +128,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
static constexpr auto KPerBlock = K0PerBlock * K1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using FloatAcc = FloatAcc_;
using FloatCShuffle = FloatAcc;
using Block2CTileMap = Block2CTileMap_;
......@@ -292,7 +304,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
......@@ -372,7 +384,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
......@@ -384,11 +396,54 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
Number<NRepeat / CShuffleNRepeatPerShuffle>{},
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
......@@ -425,6 +480,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// ignore = p_workspace; // TODO: for reduction
// lds max alignment
constexpr auto max_lds_align = K1;
......@@ -468,16 +525,187 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
uint32_t block_idx = block_mapping.get_block_idx();
bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx;
bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
block_idx < block_mapping.dp_start_block_idx;
uint32_t iter_start, iter_end;
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
// if(threadIdx.x == 0)
// printf("xxx bid:%d, is_sk_block:%d, is_dp_block:%d\n", static_cast<int>(blockIdx.x),
// is_sk_block, is_dp_block);
if(!is_sk_block && !is_dp_block)
if(is_padding_block)
return;
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto MReduceIters =
math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index(
0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse =
make_multi_index(0,
-1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n = make_multi_index(
0,
0,
0,
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_n_reverse =
make_multi_index(0,
0,
0,
-1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
acc_buf;
acc_buf.Clear();
// start to compute
auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
FloatAcc, // SrcData,
FloatAcc, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_desc), // DstDesc,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
2, // SrcVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(static_cast<index_t>(tile_acc_offset_start), I0, I0)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, // SrcData,
FloatC, // DstData,
decltype(acc_thread_buf_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
2, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(spatial_idx[I0], I0, spatial_idx[I1], I0),
CElementwiseOperation{}};
// block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
static_for<0, MReduceIters, 1>{}([&](auto i_m_reduce) {
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
static_cast<FloatAcc*>(p_workspace) + i,
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_desc,
make_multi_index(I0),
parcial_acc_buf);
static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_desc.CalculateOffset(make_tuple(i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
acc_store.Run(acc_thread_buf_desc,
make_multi_index(I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_load_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
});
if constexpr(i_m_reduce != MReduceIters - 1)
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
});
return;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
......@@ -602,15 +830,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
static_cast<FloatAcc*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(
make_tuple(CShuffleMRepeatPerShuffle,
......@@ -701,14 +936,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mperblock_nblock_nperblock,
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
......@@ -717,6 +952,32 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
constexpr auto nxdlperwave_forward_step =
......@@ -757,19 +1018,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
else if(is_sk_block)
{
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave, I0, nxdlperwave, I0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf),
decltype(c_block_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
else if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
c_block_copy_lds_to_global
.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_block_desc_mblock_mperblock_nblock_nperblock,
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
}
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
......@@ -795,6 +1079,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
mxdlperwave_forward_step);
}
});
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
}
// exit condition
......@@ -802,6 +1097,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
......
......@@ -178,21 +178,46 @@ struct MDiv
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ uint32_t div(uint32_t dividend) const
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend, multiplier, shift);
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend, uint32_t& quotient, uint32_t& remainder) const
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient = div(dividend);
remainder = dividend - (quotient * divisor);
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
__host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }
__host__ __device__ uint32_t get() const { return divisor; }
};
struct MDiv2
{
// 1 dword -> 2 dword storage, divisor need compute from runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
__host__ __device__ MDiv2(uint32_t divisor_)
{
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
} // namespace ck
......@@ -240,5 +240,21 @@ struct less
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
return Y;
}
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
{
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
return Number<Y>{};
}
} // namespace math
} // namespace ck
#pragma once
#include <hip/hip_runtime.h>
#include <stdint.h>
namespace ck {
struct workgroup_barrier {
__device__ workgroup_barrier(uint32_t * ptr) :
base_ptr(ptr)
{}
__device__ uint32_t ld(uint32_t offset)
{
#if 0
float d = llvm_amdgcn_raw_buffer_load_fp32(
amdgcn_make_buffer_resource(base_ptr),
0,
offset,
AMDGCN_BUFFER_GLC);
union cvt {
float f32;
uint32_t u32;
};
cvt x;
x.f32 = d;
return x.u32;
#endif
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
}
__device__ void wait_eq(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) != value){}
}
__syncthreads();
}
__device__ void wait_lt(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) < value){}
}
__syncthreads();
}
__device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
{
if(threadIdx.x == 0){
while(atomicCAS(base_ptr + offset, compare, value) != compare){}
}
__syncthreads();
}
// enter critical zoon, assume buffer is zero when launch kernel
__device__ void aquire(uint32_t offset)
{
wait_set(offset, 0, 1);
}
// exit critical zoon, assume buffer is zero when launch kernel
__device__ void release(uint32_t offset)
{
wait_set(offset, 1, 0);
}
__device__ void inc(uint32_t offset)
{
__syncthreads();
if(threadIdx.x == 0){
atomicAdd(base_ptr + offset, 1);
}
}
uint32_t * base_ptr;
};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment