Commit 888f1d68 authored by Chao Liu's avatar Chao Liu
Browse files

replace raw pointer with DynamicBuffer in blockwise and threadwise gemm

parent 35d68cf8
......@@ -503,10 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename CThreadBuffer>
__device__ void Run_pipelined_2x2(const FloatA* p_a_block,
const FloatB* p_b_block,
CThreadBuffer c_thread_buf) const
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -548,8 +548,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
auto a_thread_buf = make_dynamic_buffer(p_a_thread);
auto b_thread_buf = make_dynamic_buffer(p_b_thread);
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
FloatB,
......@@ -561,7 +561,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read A_sub_0
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<0>{}),
p_a_block,
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
a_thread_buf);
......@@ -569,7 +569,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read B_sub_0
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<0>{}),
p_b_block,
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf);
......@@ -577,7 +577,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read B_sub_1
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
p_b_block,
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
b_thread_buf);
......@@ -585,7 +585,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read A_sub_1
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
p_a_block,
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
a_thread_buf);
......@@ -611,7 +611,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read A_sub_0
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(k, Number<0>{}),
p_a_block,
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
a_thread_buf);
......@@ -627,7 +627,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read B_sub_0
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(k, Number<0>{}),
p_b_block,
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<0>{}),
b_thread_buf);
......@@ -643,7 +643,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read B_sub_1
b_thread_copy_.Run(BlockMatrixB{},
make_tuple(k, Number<NPerLevel1Cluster>{}),
p_b_block,
b_block_buf,
b_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
b_thread_buf);
......@@ -651,7 +651,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
// read A_sub_1
a_thread_copy_.Run(BlockMatrixA{},
make_tuple(k, Number<MPerLevel1Cluster>{}),
p_a_block,
a_block_buf,
a_thread_mtx_desc_,
make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
a_thread_buf);
......@@ -690,9 +690,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
}
template <typename CThreadBuffer>
__device__ void
Run(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
......@@ -706,14 +707,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
}
else
{
Run_naive(p_a_block, p_b_block, c_thread_buf);
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
}
#else
Run_naive(p_a_block, p_b_block, c_thread_buf);
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
#endif
}
};
......
......@@ -751,6 +751,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
......@@ -762,12 +774,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
if constexpr(HasMainKBlockLoop)
{
FloatAB* p_a_block_even = p_a_block_double;
FloatAB* p_b_block_even = p_b_block_double;
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
index_t k_block_data_begin = 0;
// LDS double buffer: main body
......@@ -791,7 +797,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
......@@ -814,7 +820,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
......@@ -841,7 +847,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
......@@ -850,16 +856,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size,
c_thread_buf);
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
......
......@@ -1321,12 +1321,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// Assume:
// 1. src:
// 1. src_desc is known at compile-time
// 2. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
// compile-time distance to src_reference_idx
// 3. use #-iterator
// 4. use #-iterator
// 2. dst:
// 1. dst_desc is known at compile-time
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
// 2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
// compile-time distance to dst_reference_idx
// 3. use direct address calculation (lower of coordinate)
......@@ -1364,10 +1367,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
template <typename SrcRefToOriginDisplacement,
typename DstRefToOriginDisplacement,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcData* p_src,
const SrcBuffer& src_buf,
const DstDesc&,
const DstRefToOriginDisplacement&,
DstBuffer& dst_buf) const
......@@ -1375,6 +1379,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
#if 0 // debug
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
#endif
static_assert(is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<
......@@ -1462,8 +1470,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
src_desc, src_data_coord);
src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
is_src_valid
? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset()]
: src_vector_t{0};
// copy data from src_tmp_buf to dst_tmp_buf (data cast data from SrcData to DstData)
......
......@@ -15,6 +15,10 @@ struct StaticBuffer : public vector_type<ScalarType, N>
using base = vector_type<ScalarType, N>;
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <typename T, index_t N>
......@@ -65,6 +69,10 @@ struct DynamicBuffer
{
return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <typename T>
......
......@@ -28,7 +28,7 @@
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment