Commit ca8b5c79 authored by carlushuang

update reduction for streamk(not ready yet)

parent b2a49620
......@@ -111,6 +111,16 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
threadwise_transfer_.SetDstSliceOrigin(dst_desc, dst_slice_origin_idx);
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
......
......@@ -141,7 +141,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
const auto kernel = kernel_gemm_xdlops_streamk<GridwiseGemm>;
// TODO: remove clear buffer for streamk kernels
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Atomic)
{
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
}
else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
char* workspace_semaphore = reinterpret_cast<char*>(karg.p_workspace_);
workspace_semaphore =
workspace_semaphore +
karg.block_mapping.get_workspace_size_for_acc(sizeof(GridwiseGemm::FloatAcc));
hipGetErrorString(hipMemset(
workspace_semaphore, 0, karg.block_mapping.get_workspace_size_for_semaphore()));
}
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -151,6 +165,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg.p_workspace_,
karg.M,
karg.N,
karg.K,
......@@ -170,6 +185,27 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* p_arg = dynamic_cast<const Argument*>(pArg);
if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
return p_arg->block_mapping.get_workspace_size(sizeof(GridwiseGemm::FloatAcc));
}
else
{
return 0;
}
}
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
}
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
......
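To put the two workspace overrides above in context, a host caller is expected to query the workspace size (zero under the Atomic strategy), allocate the buffer, and hand the pointer back before launching the kernel. The helper below is only an illustrative sketch of that call sequence; attach_streamk_workspace and its minimal error handling are hypothetical and not part of this commit.

#include <cstddef>
#include <hip/hip_runtime.h>

// Sketch: allocate and attach the Stream-K workspace for a device op that
// exposes GetWorkSpaceSize/SetWorkSpacePointer (as DeviceGemmXdlStreamK does).
template <typename DeviceOp, typename Arg>
bool attach_streamk_workspace(const DeviceOp& op, Arg& arg, void*& p_workspace)
{
    p_workspace = nullptr;
    const std::size_t ws_bytes = op.GetWorkSpaceSize(&arg); // 0 when ReductionStrategy is Atomic
    if(ws_bytes == 0)
        return true; // nothing to allocate
    if(hipMalloc(&p_workspace, ws_bytes) != hipSuccess)
        return false;
    op.SetWorkSpacePointer(&arg, p_workspace); // kernel later reads it as karg.p_workspace_
    return true;
}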
......@@ -637,44 +637,53 @@ struct BlockToCTileMap_3DGrid_KSplit
}
};
enum StreamKReductionStrategy
{
Atomic = 0, // sk blocks use atomics to do the reduction
Reduction, // let some workgroups be responsible for the reduction operation
};
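// Illustrative sketch (not part of this commit): how an epilogue is expected to
// dispatch on the chosen strategy. `c_ptr`, `partial_ptr` and `p_tile_semaphore`
// are hypothetical placeholders, and the accumulator is assumed to be float.
template <StreamKReductionStrategy Strategy>
__device__ void store_sk_partial_sketch(float* c_ptr,
                                        float* partial_ptr,
                                        uint32_t* p_tile_semaphore,
                                        float acc)
{
    if constexpr(Strategy == StreamKReductionStrategy::Atomic)
    {
        atomicAdd(c_ptr, acc); // every sk block accumulates directly into C
    }
    else
    {
        *partial_ptr = acc;              // spill the partial tile into the workspace
        __threadfence();                 // make it visible before signalling
        atomicAdd(p_tile_semaphore, 1u); // the reduction workgroup waits on this counter
    }
}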
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
uint32_t TileSwizzleSubM_ = 8>
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8>
struct BlockToCTileMap_GemmStreamK
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t sk_total_iters;
// uint32_t sk_total_iters;
uint32_t dp_start_block_idx;
uint32_t dp_iters_per_block;
uint32_t dp_num_blocks;
// uint32_t dp_iters_per_block;
// uint32_t dp_num_blocks;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
// uint32_t tiles_cover_big_blocks; // for reduction
// uint32_t total_acc_buffers; // for reduction
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv n_tiles;
MDiv eqav_tiles_big; // for reduction
MDiv eqav_tiles_little; // for reduction
// MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
static int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
if(v)
r = atoi(v);
return r;
}
// prefer construct on host
BlockToCTileMap_GemmStreamK(uint32_t m,
uint32_t n,
......@@ -727,8 +736,9 @@ struct BlockToCTileMap_GemmStreamK
sk_tiles = partial_dispatche_tiles + num_cu;
}
dp_iters_per_block = k_iters_per_tile.get();
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_iters_per_block = k_iters_per_tile.get();
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
uint32_t dp_num_blocks = 0;
{
uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
......@@ -775,7 +785,6 @@ struct BlockToCTileMap_GemmStreamK
// give a chance to control num of sk blocks
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
sk_num_blocks = env_get_int("sk_num_blocks", sk_num_blocks);
if(sk_num_blocks == 0)
{
......@@ -807,7 +816,16 @@ struct BlockToCTileMap_GemmStreamK
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
}
}
n_tiles = MDiv(math::integer_divide_ceil(n, NPerBlock));
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
// tile_swizzle_sub_m_rem =
// MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);
......@@ -831,9 +849,28 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_big_block);
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ dim3 get_grid_dims() const
{
return dim3(dp_start_block_idx + dp_num_blocks, 1, 1);
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
}
else
return dim3(reduction_start_block_idx, 1, 1);
}
__device__ uint32_t get_block_idx() const
......@@ -858,6 +895,8 @@ struct BlockToCTileMap_GemmStreamK
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
......@@ -882,10 +921,11 @@ struct BlockToCTileMap_GemmStreamK
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t /*n*/) const
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
n_tiles.divmod(tile_idx, m_tile_idx, n_tile_idx);
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// return make_tuple(m_tile_idx, n_tile_idx);
// swizzle tile
......@@ -901,7 +941,7 @@ struct BlockToCTileMap_GemmStreamK
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles.get();
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
......@@ -911,6 +951,115 @@ struct BlockToCTileMap_GemmStreamK
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
// __host__ __device__ uint32_t get_workspace_offset_for_semaphore() const
// {
// // workspace contains 2 parts: 1) partial reduction buffer, 2) semaphore for cross-wg sync
// // we let 1) start from offset:0, 2) start from the end of 1)
// // NOTE: offset is in unit of byte
// return get_total_acc_buffers() *
// }
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
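// Illustrative sketch (not part of this commit): byte offset of the semaphore
// region inside the workspace. The DeviceGemmXdlStreamK launcher above derives
// the same offset when it memsets the semaphores: they sit immediately after
// the 128-byte-aligned accumulation buffers. Hypothetical helper, named with a
// _sketch suffix to make clear it is not an API of this struct.
__host__ __device__ uint32_t
get_workspace_offset_for_semaphore_sketch(uint32_t acc_element_bytes) const
{
    return get_workspace_size_for_acc(acc_element_bytes);
}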
__device__ uint32_t get_tile_intersections(uint32_t tiles_, const MDiv& eqav_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1;
uint32_t quo_, rem_;
eqav_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_eqav_tiles_ + rem_;
// return tile_idx_ / eqav_tiles_ * max_eqav_tiles_ + (tile_idx_ % eqav_tiles_);
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little);
// printf("reverse tile:%u, %u/%u\n", tile_idx_little_reverse, touched_sk_blocks,
// current_intersec);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
// uint32_t touched_tiles = (block_idx_ * iters_per_big_sk_block + iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
// uint32_t touched_tiles = (block_idx_little_reverse * iters_per_little_sk_block +
// iters - 1) / iters;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck
......@@ -178,21 +178,46 @@ struct MDiv
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ uint32_t div(uint32_t dividend) const
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend, multiplier, shift);
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend, uint32_t& quotient, uint32_t& remainder) const
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient = div(dividend);
remainder = dividend - (quotient * divisor);
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
__host__ __device__ uint32_t operator/(uint32_t dividend) const { return div(dividend); }
__host__ __device__ uint32_t get() const { return divisor; }
};
struct MDiv2
{
// stores 2 dwords (multiplier + shift) in place of the 1-dword divisor;
// the divisor itself is not kept and must be supplied by the caller at divmod time
uint32_t multiplier;
uint32_t shift; // TODO: 8 bits would be enough
// prefer construct on host
__host__ __device__ MDiv2(uint32_t divisor_)
{
ck::tie(multiplier, shift) = MagicDivision::CalculateMagicNumbers(divisor_);
}
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
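// Illustrative sketch (not part of this commit): MDiv keeps the divisor and can
// answer divmod on its own, while MDiv2 drops it to save one dword and expects
// the caller to pass the divisor back in. The numbers below are examples only.
inline void mdiv_usage_sketch()
{
    MDiv d3(24);  // multiplier, shift and divisor are stored
    MDiv2 d2(24); // only multiplier and shift are stored

    uint32_t q, r;
    d3.divmod(100, q, r);     // q = 4, r = 4
    d2.divmod(100, 24, q, r); // same result, divisor supplied by the caller
}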
} // namespace ck
......@@ -240,5 +240,21 @@ struct less
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
// TODO: X must be in the range 2 ~ 0x7fffffff; 0, 1, or values larger than 0x7fffffff fail to compile
constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1));
return Y;
}
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X> x)
{
// TODO: X must be in the range 2 ~ 0x7fffffff; 0, 1, or values larger than 0x7fffffff fail to compile
constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1));
return Number<Y>{};
}
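// Illustrative compile-time checks (not part of this commit) for the helpers above.
static_assert(next_power_of_two<2>() == 2, "");
static_assert(next_power_of_two<5>() == 8, "");
static_assert(next_power_of_two<96>() == 128, "");
static_assert(next_power_of_two(Number<96>{}).value == 128, "");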
} // namespace math
} // namespace ck
#pragma once
#include <hip/hip_runtime.h>
#include <stdint.h>
namespace ck {
struct workgroup_barrier {
__device__ workgroup_barrier(uint32_t * ptr) :
base_ptr(ptr)
{}
__device__ uint32_t ld(uint32_t offset)
{
#if 0
float d = llvm_amdgcn_raw_buffer_load_fp32(
amdgcn_make_buffer_resource(base_ptr),
0,
offset,
AMDGCN_BUFFER_GLC);
union cvt {
float f32;
uint32_t u32;
};
cvt x;
x.f32 = d;
return x.u32;
#endif
return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED);
}
__device__ void wait_eq(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) != value){}
}
__syncthreads();
}
__device__ void wait_lt(uint32_t offset, uint32_t value)
{
if(threadIdx.x == 0){
while(ld(offset) < value){}
}
__syncthreads();
}
__device__ void wait_set(uint32_t offset, uint32_t compare, uint32_t value)
{
if(threadIdx.x == 0){
while(atomicCAS(base_ptr + offset, compare, value) != compare){}
}
__syncthreads();
}
// enter critical section; assumes the buffer is zeroed when the kernel is launched
__device__ void aquire(uint32_t offset)
{
wait_set(offset, 0, 1);
}
// exit critical section; assumes the buffer is zeroed when the kernel is launched
__device__ void release(uint32_t offset)
{
wait_set(offset, 1, 0);
}
__device__ void inc(uint32_t offset)
{
__syncthreads();
if(threadIdx.x == 0){
atomicAdd(base_ptr + offset, 1);
}
}
uint32_t * base_ptr;
};
} // namespace ck
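Putting the pieces together, the intended device-side use of workgroup_barrier looks roughly like the sketch below: each stream-k block bumps a per-tile counter once its partial accumulators are in the workspace, and the workgroup that performs the tile's final reduction spins until every contributor has checked in. This is an illustrative sketch only; reduce_tile_example, num_contributors, and the surrounding buffers are hypothetical and not part of this commit.

__device__ void reduce_tile_example(uint32_t* p_semaphore_base,
                                    uint32_t tile_idx,
                                    uint32_t num_contributors,
                                    bool is_reduction_block)
{
    ck::workgroup_barrier barrier(p_semaphore_base);

    if(!is_reduction_block)
    {
        // ... partial accumulators for this tile have been stored to the workspace ...
        barrier.inc(tile_idx); // signal: one more partial result is ready
    }
    else
    {
        barrier.wait_eq(tile_idx, num_contributors); // wait until every partial has arrived
        // ... sum the partials from the workspace and write the final C tile ...
    }
}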