Commit ca8b5c79 authored by carlushuang

update reduction for streamk(not ready yet)

parent b2a49620
@@ -111,6 +111,16 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
threadwise_transfer_.SetDstSliceOrigin(dst_desc, dst_slice_origin_idx);
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
...
@@ -141,7 +141,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
// TODO: remove clear buffer for streamk kernels
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
}
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
workspace_semaphore =
workspace_semaphore +
karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
}
ave_time = launch_and_time_kernel(stream_config,
kernel,
@@ -151,6 +165,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
@@ -170,6 +185,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc));
}
else
{
return 0;
}
}
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
}
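// Hedged usage sketch (not part of this commit): with the Reduction strategy the caller is
// expected to query the workspace size and hand a device buffer back to the op before the
// kernel is launched. `gemm` and `p_arg` are hypothetical placeholders.
//
//   std::size_t ws_bytes = gemm.GetWorkSpaceSize(p_arg);  // acc buffers + semaphores, or 0
//   void* p_ws = nullptr;
//   if(ws_bytes != 0)
//       (void)hipMalloc(&p_ws, ws_bytes);
//   gemm.SetWorkSpacePointer(p_arg, p_ws);
//   // ... then run the op through its invoker as usual.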
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
...
@@ -637,44 +637,53 @@ struct BlockToCTileMap_3DGrid_KSplit
}
};
enum StreamKReductionStrategy
{
Atomic = 0, // sk blocks use atomics to do the reduction
Reduction,  // dedicated workgroups are responsible for the reduction step
};
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
// uint32_t sk_total_iters;
uint32_t dp_start_block_idx;
// uint32_t dp_iters_per_block;
// uint32_t dp_num_blocks;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
// uint32_t tiles_cover_big_blocks; // for reduction
// uint32_t total_acc_buffers;      // for reduction
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv eqav_tiles_big;    // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
static int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
if(v)
r = atoi(v);
return r;
}
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
@@ -727,8 +736,9 @@ struct BlockToCTileMap_GemmStreamK
sk_tiles = partial_dispatche_tiles + num_cu;
}
uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
@@ -775,7 +785,6 @@ struct BlockToCTileMap_GemmStreamK
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
sk_num_blocks = env_get_int("sk_num_blocks", sk_num_blocks);
if(sk_num_blocks == 0)
{
@@ -807,7 +816,16 @@ struct BlockToCTileMap_GemmStreamK
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
// tile_swizzle_sub_m_rem =
//     MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);
@@ -831,9 +849,28 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_big_block);
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
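// Worked example (hypothetical numbers, for illustration only): with k_iters_per_tile = 8,
// sk_num_blocks = 6, sk_num_big_blocks = 2 and k_iters_per_big_block = 6, the sk blocks
// cover 2 * 6 + 4 * 5 = 32 k-iterations, so get_sk_tiles() = 32 / 8 = 4 stream-k C tiles.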
__host__ __device__ dim3 get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
}
else
return dim3(reduction_start_block_idx, 1, 1);
}
__device__ uint32_t get_block_idx() const
@@ -858,6 +895,8 @@ struct BlockToCTileMap_GemmStreamK
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
@@ -882,10 +921,11 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// return make_tuple(m_tile_idx, n_tile_idx);
// swizzle tile
@@ -901,7 +941,7 @@ struct BlockToCTileMap_GemmStreamK
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
@@ -911,6 +951,115 @@ struct BlockToCTileMap_GemmStreamK
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
// __host__ __device__ uint32_t get_workspace_offset_for_semaphore() const
// {
// // workspace contains 2 part, 1) partial reduction buffer 2) semaphore for cross-wg sync
// // we let 1) start from offset:0, 2) start from the end of 1)
// // NOTE: offset is in unit of byte
// return get_total_acc_buffers() *
// }
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
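// Layout sketch (drawn from the two helpers above): the workspace is one allocation,
// [ partial-acc buffers | per-tile semaphores ], with the acc region rounded up to a
// 128-byte boundary. E.g. (hypothetical numbers) MPerBlock = NPerBlock = 128, 10 acc
// buffers and a 4-byte FloatAcc give 128 * 128 * 10 * 4 = 655360 bytes of acc storage,
// followed by get_sk_tiles() uint32_t counters.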
__device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
uint32_t quo_, rem_;
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_eqav_tiles_ + rem_;
// return tile_idx_ / eqav_tiles_ * max_eqav_tiles_ + (tile_idx_ % eqav_tiles_);
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
// printf("reverse tile:%u, %u/%u\n", tile_idx_little_reverse, touched_sk_blocks,
// current_intersec);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
// uint32_t touched_tiles = (block_idx_ * iters_per_big_sk_block + iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
// uint32_t touched_tiles = (block_idx_little_reverse * iters_per_little_sk_block +
// iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck
@@ -14,8 +14,9 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {
@@ -28,6 +29,7 @@ __global__ void
kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
const typename GridwiseGemm::FloatAB* p_b_grid,
typename GridwiseGemm::FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
@@ -45,6 +47,7 @@ __global__ void
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_workspace,
M,
N,
K,
@@ -52,18 +55,26 @@ __global__ void
StrideB,
StrideC,
block_mapping,
static_cast<void*>(p_shared));
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = p_workspace;
ignore = M;
ignore = N;
ignore = K;
ignore = StrideA;
ignore = StrideB;
ignore = StrideC;
ignore = block_mapping;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <index_t BlockSize,
typename Block2CTileMap_,
typename FloatAB_,
typename FloatAcc_,
typename FloatC_,
typename ALayout,
typename BLayout,
@@ -117,6 +128,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
static constexpr auto KPerBlock = K0PerBlock * K1;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using FloatAcc = FloatAcc_;
using FloatCShuffle = FloatAcc;
using Block2CTileMap = Block2CTileMap_;
@@ -292,7 +304,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize();
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB),
@@ -372,7 +384,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
@@ -384,11 +396,54 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat / CShuffleMRepeatPerShuffle>{},
Number<CShuffleMRepeatPerShuffle * MWave * MPerXDL>{},
Number<NRepeat / CShuffleNRepeatPerShuffle>{},
Number<CShuffleNRepeatPerShuffle * NWave * NPerXDL>{}));
}
__host__ __device__ static constexpr auto GetClusterLengthReduction()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
constexpr auto NPerBlockReduction =
NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
constexpr auto MPerBlockReduction =
(BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
return Sequence<MPerBlockReduction, NPerBlockReduction>{};
}
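// Hedged example of the resulting cluster shape (hypothetical parameters): with
// NPerBlock = 128, CBlockTransferScalarPerVector_NWaveNPerXDL = 4 and BlockSize = 256,
// NPerBlockPow2 = 128, so the reduction cluster is Sequence<(256 + 31) / 32, 128 / 4>
// = Sequence<8, 32>.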
__host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
{
const auto c_partial_acc_block_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(NPerBlock, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock),
make_tuple(I1, MPerBlock));
}
}();
return c_partial_acc_block_m_n;
}
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
void* p_workspace,
index_t M,
index_t N,
index_t K,
@@ -425,6 +480,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// ignore = p_workspace; // TODO: for reduction
// lds max alignment
constexpr auto max_lds_align = K1;
@@ -468,16 +525,187 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
uint32_t block_idx = block_mapping.get_block_idx();
bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
block_idx < block_mapping.reduction_start_block_idx;
bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
block_idx < block_mapping.dp_start_block_idx;
uint32_t iter_start, iter_end;
block_mapping.get_block_itr(block_idx, iter_start, iter_end);
uint32_t total_iter_length = iter_end - iter_start;
// if(threadIdx.x == 0)
//     printf("xxx bid:%d, is_sk_block:%d, is_dp_block:%d\n", static_cast<int>(blockIdx.x),
//            is_sk_block, is_dp_block);
if(is_padding_block)
return;
uint32_t* p_semaphore =
reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
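// Workspace layout (see get_workspace_size_for_acc/_for_semaphore): the partial-acc
// buffers sit at offset 0 and p_semaphore points at the trailing region, one uint32_t
// counter per stream-k C tile; sk blocks increment it, the reduction block polls it.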
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
if(is_reduction_block)
{
// descriptors
constexpr auto cluster_length_reduce = GetClusterLengthReduction();
constexpr auto MReduceIters =
math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
constexpr auto NReduceIters = math::integer_divide_ceil(
Number<NPerBlock>{},
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto acc_thread_buf_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
constexpr auto partial_acc_load_step_n = make_multi_index(
0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_n_reverse =
make_multi_index(0,
-1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_load_step_m =
make_multi_index(cluster_length_reduce.At(I0), 0);
constexpr auto partial_acc_store_step_n = make_multi_index(
0,
0,
0,
cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_n_reverse =
make_multi_index(0,
0,
0,
-1 * (MReduceIters - 1) * cluster_length_reduce.At(I1) *
CBlockTransferScalarPerVector_NWaveNPerXDL);
constexpr auto partial_acc_store_step_m =
make_multi_index(0, cluster_length_reduce.At(I0), 0, 0);
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
parcial_acc_buf;
StaticBuffer<AddressSpaceEnum::Vgpr,
FloatAcc,
CBlockTransferScalarPerVector_NWaveNPerXDL,
true>
acc_buf;
acc_buf.Clear();
// start to compute
auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
workgroup_barrier wg_barrier(p_semaphore);
uint32_t tile_acc_offset_start =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
uint32_t tile_acc_offset_end =
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
auto acc_load = ThreadwiseTensorSliceTransfer_v2<
FloatAcc, // SrcData,
FloatAcc, // DstData,
decltype(c_partial_acc_block_m_n), // SrcDesc,
decltype(acc_thread_buf_desc), // DstDesc,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
2, // SrcVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
1, // SrcScalarStrideInVector,
false // SrcResetCoordinateAfterRun,
>{c_partial_acc_block_m_n,
make_multi_index(static_cast<index_t>(tile_acc_offset_start), I0, I0)};
auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc, // SrcData,
FloatC, // DstData,
decltype(acc_thread_buf_desc), // SrcDesc,
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
CElementwiseOperation, // ElementwiseOperation,
Sequence<CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
Sequence<I0>, // DimAccessOrder,
2, // DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp,
1, // DstScalarStrideInVector,
false // DstResetCoordinateAfterRun,
>{c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(spatial_idx[I0], I0, spatial_idx[I1], I0),
CElementwiseOperation{}};
// block synchronization
wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
using Accumulation = ck::detail::
AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, FloatAcc>;
static_for<0, MReduceIters, 1>{}([&](auto i_m_reduce) {
static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
{
auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
static_cast<FloatAcc*>(p_workspace) + i,
c_partial_acc_block_m_n.GetElementSpaceSize());
acc_load.Run(c_partial_acc_block_m_n,
c_partial_acc_buf,
acc_thread_buf_desc,
make_multi_index(I0),
parcial_acc_buf);
static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}(
[&](auto i_vec) {
constexpr auto offset =
acc_thread_buf_desc.CalculateOffset(make_tuple(i_vec));
Accumulation::Calculate(acc_buf(Number<offset>{}),
parcial_acc_buf[Number<offset>{}]);
});
}
acc_store.Run(acc_thread_buf_desc,
make_multi_index(I0),
acc_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
if constexpr(i_n_reduce != (NReduceIters - 1))
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n);
}
else
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_n_reverse);
acc_store.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_n_reverse);
}
});
if constexpr(i_m_reduce != MReduceIters - 1)
{
acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
partial_acc_load_step_m);
acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
partial_acc_store_step_m);
}
});
return;
}
}
// offset for last acc buffer of this block
uint32_t block_acc_offset =
(block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
NPerBlock;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
@@ -602,15 +830,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();
constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared_block),
c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
static_cast<FloatAcc*>(p_workspace) + block_acc_offset,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_tuple(make_freeze_transform(I0), // freeze mblock
make_unmerge_transform(
make_tuple(CShuffleMRepeatPerShuffle,
@@ -701,14 +936,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
@@ -717,6 +952,32 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
0),
c_element_op};
// LDS to global partial acc
auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
ThisThreadBlock, // index_t BlockSize,
CElementwiseOperation, // ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerXDL,
1,
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun
{c_block_desc_mblock_mpershuffle_nblock_npershuffle,
make_multi_index(0, 0, 0, 0),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_multi_index(0, 0, 0, 0),
c_element_op};
constexpr auto mxdlperwave_forward_step =
make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0);
constexpr auto nxdlperwave_forward_step =
@@ -757,19 +1018,42 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
else if(is_sk_block)
{
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
// constexpr offset
c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
make_tuple(mxdlperwave, I0, nxdlperwave, I0));
c_block_copy_lds_to_partial_acc
.template Run<decltype(c_block_buf),
decltype(c_partial_acc_buf),
InMemoryDataOperationEnum::Set>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
c_partial_acc_buf);
}
else if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
c_block_copy_lds_to_global
.template Run<decltype(c_block_buf),
decltype(c_grid_buf),
InMemoryDataOperationEnum::AtomicAdd>(
c_block_desc_mblock_mpershuffle_nblock_npershuffle,
c_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
}
}
// move on nxdlperwave dimension
if constexpr(nxdlperwave_forward_sweep &&
@@ -795,6 +1079,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
mxdlperwave_forward_step);
}
});
if constexpr(Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
if(is_sk_block)
{
// increase the counter for this tile
workgroup_barrier wg_barrier(p_semaphore);
wg_barrier.inc(tile_idx);
}
}
}
// exit condition
@@ -802,6 +1097,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
if(iter_end <= iter_start)
break;
if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction)
{
block_acc_offset -= MPerBlock * NPerBlock;
}
// make sure next loop LDS is ready for use
block_sync_lds();
}
...
@@ -178,21 +178,46 @@ struct MDiv
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
__host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }
__host__ __device__ uint32_t get() const { return divisor; }
};
struct MDiv2
{
// 2-dword storage instead of MDiv's 3: the divisor is not kept and must be supplied at runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer construct on host
__host__ __device__ MDiv2(uint32_t divisor_)
{
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
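// Hedged usage sketch: MDiv keeps the divisor (3 dwords), MDiv2 drops it (2 dwords) and
// asks the caller to pass the divisor back in for the remainder computation.
//
//   MDiv  kd(8);
//   MDiv2 nd(17);
//   uint32_t q, r;
//   kd.divmod(42, q, r);      // q = 5, r = 2
//   nd.divmod(42, 17, q, r);  // divisor re-supplied: q = 2, r = 8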
} // namespace ck
@@ -240,5 +240,21 @@ struct less
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
// TODO: X needs to be in 2 ~ 0x7fffffff; 0, 1, or values above 0x7fffffff will fail to compile
constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
return Y;
}
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
{
// TODO: X needs to be in 2 ~ 0x7fffffff; 0, 1, or values above 0x7fffffff will fail to compile
constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
return Number<Y>{};
}
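// Hedged sanity checks (within the stated 2 ~ 0x7fffffff domain):
//   static_assert(next_power_of_two<2>() == 2, "");
//   static_assert(next_power_of_two<96>() == 128, "");
//   static_assert(next_power_of_two(Number<130>{}).value == 256, "");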
} // namespace math
} // namespace ck
#pragma once
#include <hip/hip_runtime.h>
#include <stdint.h>
namespace ck {
struct workgroup_barrier {
__device__ workgroup_barrier(uint32_t * ptr) :
base_ptr(ptr)
{}
__device__ uint32_t ld(uint32_t offset)
{
#if 0
float d = llvm_amdgcn_raw_buffer_load_fp32(
amdgcn_make_buffer_resource(base_ptr),
0,
offset,
AMDGCN_BUFFER_GLC);
union cvt {
float f32;
uint32_t u32;
};
cvt x;
x.f32 = d;
return x.u32;
#endif
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
}
__device__ void wait_eq(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) != value){}
}
__syncthreads();
}
__device__ void wait_lt(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) < value){}
}
__syncthreads();
}
__device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
{
if(threadIdx.x == 0){
while(atomicCAS(base_ptr + offset, compare, value) != compare){}
}
__syncthreads();
}
// enter critical zone; assumes the buffer is zero when the kernel is launched
__device__ void aquire(uint32_t offset)
{
wait_set(offset, 0, 1);
}
// exit critical zone; assumes the buffer is zero when the kernel is launched
__device__ void release(uint32_t offset)
{
wait_set(offset, 1, 0);
}
__device__ void inc(uint32_t offset)
{
__syncthreads();
if(threadIdx.x == 0){
atomicAdd(base_ptr + offset, 1);
}
}
uint32_t * base_ptr;
};
} // namespace ck
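// Hedged usage sketch, mirroring how the stream-k kernel uses this type: producer
// workgroups bump a per-tile counter once their partial results are in global memory,
// and the consumer workgroup spins until the expected number of producers has arrived.
//
//   ck::workgroup_barrier barrier(p_semaphore);   // p_semaphore: zero-initialized u32 array
//   barrier.inc(tile_idx);                        // producer (sk block)
//   barrier.wait_eq(tile_idx, num_contributors);  // consumer (reduction block)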