Commit 888f1d68 authored by Chao Liu

Replace raw pointers with DynamicBuffer in blockwise and threadwise GEMM

parent 35d68cf8
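The pattern applied throughout this commit: instead of handing raw FloatA*/FloatB* pointers to the blockwise and threadwise GEMM, callers wrap each pointer once with make_dynamic_buffer, and the GEMM routines take the resulting buffer types as template parameters. A minimal sketch of the wrapper idea, assuming a much simpler shape than the repo's real DynamicBuffer/PointerWrapper (which also carry the AsType<X> vector view used further down):

    // Simplified stand-in for the repo's DynamicBuffer: a typed view over a
    // raw pointer, so kernels can be templated on a buffer type instead of T*.
    template <typename T>
    struct DynamicBufferSketch
    {
        T* p_data_;

        __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
        __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

        __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
        __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
    };

    // Deduces the element type from the pointer, which is why the call sites in
    // this commit can drop the explicit <FloatA>/<FloatB> template arguments.
    template <typename T>
    __host__ __device__ constexpr auto make_dynamic_buffer(T* p)
    {
        return DynamicBufferSketch<T>{p};
    }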
@@ -503,10 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                     level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

-    template <typename CThreadBuffer>
-    __device__ void Run_pipelined_2x2(const FloatA* p_a_block,
-                                      const FloatB* p_b_block,
-                                      CThreadBuffer c_thread_buf) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
+                                      const BBlockBuffer& b_block_buf,
+                                      CThreadBuffer& c_thread_buf) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -548,8 +548,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
         FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];

-        auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
-        auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
+        auto a_thread_buf = make_dynamic_buffer(p_a_thread);
+        auto b_thread_buf = make_dynamic_buffer(p_b_thread);

         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
                                                                       FloatB,
@@ -561,7 +561,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_a_block,
+                           a_block_buf,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
                            a_thread_buf);
@@ -569,7 +569,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read B_sub_0
         b_thread_copy_.Run(BlockMatrixB{},
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_b_block,
+                           b_block_buf,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
                            b_thread_buf);
@@ -577,7 +577,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read B_sub_1
         b_thread_copy_.Run(BlockMatrixB{},
                            make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
-                           p_b_block,
+                           b_block_buf,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
                            b_thread_buf);
@@ -585,7 +585,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
                            make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
-                           p_a_block,
+                           a_block_buf,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
                            a_thread_buf);
@@ -611,7 +611,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read A_sub_0
             a_thread_copy_.Run(BlockMatrixA{},
                                make_tuple(k, Number<0>{}),
-                               p_a_block,
+                               a_block_buf,
                                a_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<0>{}),
                                a_thread_buf);
@@ -627,7 +627,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
                                make_tuple(k, Number<0>{}),
-                               p_b_block,
+                               b_block_buf,
                                b_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<0>{}),
                                b_thread_buf);
@@ -643,7 +643,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read B_sub_1
             b_thread_copy_.Run(BlockMatrixB{},
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
-                               p_b_block,
+                               b_block_buf,
                                b_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
                                b_thread_buf);
@@ -651,7 +651,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read A_sub_1
             a_thread_copy_.Run(BlockMatrixA{},
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
-                               p_a_block,
+                               a_block_buf,
                                a_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
                                a_thread_buf);
@@ -690,9 +690,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }

-    template <typename CThreadBuffer>
-    __device__ void
-    Run(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
         constexpr auto I0 = Number<0>{};
@@ -706,14 +707,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         if constexpr(MRepeat == 2 && NRepeat == 2)
         {
-            Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
+            Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
         }
         else
         {
-            Run_naive(p_a_block, p_b_block, c_thread_buf);
+            Run_naive(a_block_buf, b_block_buf, c_thread_buf);
         }
 #else
-        Run_naive(p_a_block, p_b_block, c_thread_buf);
+        Run_naive(a_block_buf, b_block_buf, c_thread_buf);
 #endif
     }
 };
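With the new signature, a kernel wraps its LDS tiles once and passes buffer views through the whole call chain, and the same Run call dispatches to the pipelined or naive path. Roughly like this inside device code (array names and sizes here are illustrative, not from the repo):

    // Hypothetical call site: LDS tiles wrapped as dynamic buffers once,
    // then reused for every blockwise GEMM invocation.
    __shared__ FloatAB p_a_lds[a_block_space_size];
    __shared__ FloatAB p_b_lds[b_block_space_size];

    auto a_block_buf = make_dynamic_buffer(p_a_lds);
    auto b_block_buf = make_dynamic_buffer(p_b_lds);

    blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);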
...
@@ -751,6 +751,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

+        FloatAB* p_a_block_even = p_a_block_double;
+        FloatAB* p_b_block_even = p_b_block_double;
+
+        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
+        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
+
+        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
+        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
+
+        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
+        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
@@ -762,12 +774,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1

         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_a_block_even = p_a_block_double;
-            FloatAB* p_b_block_even = p_b_block_double;
-
-            FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-            FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
             index_t k_block_data_begin = 0;

             // LDS double buffer: main body
@@ -791,7 +797,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
@@ -814,7 +820,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
@@ -841,7 +847,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
@@ -850,16 +856,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               c_thread_buf);
+            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

         // output: register to global memory
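The even/odd buffer pairs implement classic double buffering: the GEMM consumes one LDS tile while the blockwise copy stages the next K-slice into the other, and the roles swap every iteration. A self-contained host-side analogue of the same schedule (plain C++; names are hypothetical, and the __syncthreads() barriers have no host equivalent here):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int num_chunks = 8, chunk_size = 4;
        std::vector<int> src(num_chunks * chunk_size);
        for(int i = 0; i < static_cast<int>(src.size()); ++i)
            src[i] = i;

        // two "LDS" tiles: even and odd
        std::vector<int> buf[2] = {std::vector<int>(chunk_size), std::vector<int>(chunk_size)};

        // preload chunk 0 into the even buffer (mirrors the preload block above)
        std::copy(src.begin(), src.begin() + chunk_size, buf[0].begin());

        long long acc = 0;
        for(int c = 0; c < num_chunks; ++c)
        {
            int cur = c % 2, nxt = 1 - cur;

            // stage the next chunk while "computing" on the current one
            if(c + 1 < num_chunks)
                std::copy(src.begin() + (c + 1) * chunk_size,
                          src.begin() + (c + 2) * chunk_size,
                          buf[nxt].begin());

            for(int x : buf[cur]) // stand-in for blockwise_gemm.Run
                acc += x;
        }

        std::printf("sum = %lld\n", acc); // 0 + 1 + ... + 31 = 496
    }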
...
@@ -1321,12 +1321,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 // Assume:
 //   1. src:
-//     1. src_desc is known at compile-time
-//     2. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
+//     1. SrcDesc is known at compile-time
+//     2. SrcBuffer is DynamicBuffer
+//     3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
 //        compile-time distance to src_reference_idx
-//     3. use #-iterator
+//     4. use #-iterator
 //   2. dst:
-//     1. dst_desc is known at compile-time
-//     2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
+//     1. DstDesc is known at compile-time
+//     2. DstBuffer is StaticBuffer
+//     3. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
 //        compile-time distance to dst_reference_idx
-//     3. use direct address calculation (lower of coordinate)
+//     4. use direct address calculation (lower of coordinate)
@@ -1364,10 +1367,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
     template <typename SrcRefToOriginDisplacement,
               typename DstRefToOriginDisplacement,
+              typename SrcBuffer,
               typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcRefToOriginDisplacement&,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstRefToOriginDisplacement&,
                         DstBuffer& dst_buf) const
@@ -1375,6 +1379,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to be known at compile-time");

+#if 0 // debug
+        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer needs to be StaticBuffer");
+#endif
+
         static_assert(is_known_at_compile_time<
                           remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
                           is_known_at_compile_time<
@@ -1462,9 +1470,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                     src_desc, src_data_coord);

                 src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid
-                        ? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
-                        : src_vector_t{0};
+                    is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset()]
+                                 : src_vector_t{0};

                 // copy data from src_tmp_buf to dst_tmp_buf (cast data from SrcData to DstData)
                 auto dst_tmp_buf = make_static_buffer<DstData>(Number<SrcScalarPerVector>{});
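The rewritten load preserves the old semantics: fetch one src_vector_t worth of elements when the source coordinate is valid, otherwise substitute a zero vector; the access simply goes through the buffer's AsType<src_vector_t> view instead of a reinterpret_cast on a raw pointer. A standalone analogue of the guarded vector load (simplified types; names are illustrative):

    #include <array>
    #include <cstdio>

    using float4_t = std::array<float, 4>; // stand-in for src_vector_t

    // Out-of-range reads return a zero vector instead of touching invalid memory.
    float4_t load_vector_or_zero(const float* p, int offset, int valid_size)
    {
        const bool is_valid = offset >= 0 && offset + 4 <= valid_size;

        float4_t v{}; // value-initialized: the zero-fill path
        if(is_valid)
            for(int i = 0; i < 4; ++i)
                v[i] = p[offset + i];
        return v;
    }

    int main()
    {
        float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        auto a = load_vector_or_zero(data, 4, 8); // in range -> {5, 6, 7, 8}
        auto b = load_vector_or_zero(data, 6, 8); // would run past the end -> zeros
        std::printf("a[0] = %g, b[0] = %g\n", a[0], b[0]);
    }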
...
@@ -15,6 +15,10 @@ struct StaticBuffer : public vector_type<ScalarType, N>
     using base = vector_type<ScalarType, N>;

     __host__ __device__ constexpr StaticBuffer() : base{} {}
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 };

 template <typename T, index_t N>
@@ -65,6 +69,10 @@ struct DynamicBuffer
     {
         return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
     }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };

 template <typename T>
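These constexpr predicates give generic transfer code a way to assert on, or branch over, the buffer kind with no runtime cost, as the (currently disabled) static_assert in ThreadwiseDynamicTensorSliceTransfer_v4 above starts to do. A toy but compilable illustration (C++17; the two buffer types here are stand-ins, not the repo's):

    #include <cstdio>

    struct ToyStaticBuffer
    {
        static constexpr bool IsStaticBuffer() { return true; }
        static constexpr bool IsDynamicBuffer() { return false; }
    };

    struct ToyDynamicBuffer
    {
        static constexpr bool IsStaticBuffer() { return false; }
        static constexpr bool IsDynamicBuffer() { return true; }
    };

    // Generic code selects a path per buffer kind at compile time.
    template <typename Buffer>
    void describe(const Buffer&)
    {
        if constexpr(Buffer::IsStaticBuffer())
            std::printf("static buffer: compile-time offsets, register-resident\n");
        else
            std::printf("dynamic buffer: run-time offsets, memory-resident\n");
    }

    int main()
    {
        describe(ToyStaticBuffer{});  // static path
        describe(ToyDynamicBuffer{}); // dynamic path
    }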
...
@@ -28,7 +28,7 @@
 #endif

 // launch bounds
-#define CK_USE_LAUNCH_BOUNDS 1
+#define CK_USE_LAUNCH_BOUNDS 0

 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
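Note that the guard above is #ifdef, so CK_MAX_THREAD_PER_BLOCK stays defined even with the value switched to 0; the 0/1 value presumably gates the actual attribute elsewhere via #if. A sketch of that usual gating pattern (CK_KERNEL_ATTRS and the kernel name are hypothetical, not from the repo):

    // Typical way a 0/1 config macro gates __launch_bounds__ on kernels.
    #if CK_USE_LAUNCH_BOUNDS
    #define CK_KERNEL_ATTRS __global__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
    #else
    #define CK_KERNEL_ATTRS __global__
    #endif

    // With CK_USE_LAUNCH_BOUNDS set to 0, the kernel compiles without the
    // occupancy constraint, letting the compiler choose register usage freely.
    CK_KERNEL_ATTRS void gridwise_gemm_kernel(/* ... */) {}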
...