Commit 7484a103 authored by Chao Liu

update v5r1

parent ba31eb3e
@@ -6,11 +6,6 @@
 namespace ck {
 
-// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
-// A and B are visable to the whole block, C is distributed among each thread
-// If following number are power of 2, index calculation shall be greatly reduced:
-//   KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
-//   MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -58,10 +53,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                    AddressSpace::Vgpr,
                                    1>;
 
-    MatrixIndex c_thread_begin_mtx_idx_;
-
-    AThreadCopy a_thread_copy_;
-
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
         : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
           a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
@@ -183,6 +174,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
             });
         });
     }
+
+    template <typename ABlockSliceMoveStepIdx>
+    __device__ void MoveASliceWindow(const BlockMatrixA&,
+                                     const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
+    {
+        a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
+    }
+
+    private:
+    MatrixIndex c_thread_begin_mtx_idx_;
+
+    AThreadCopy a_thread_copy_;
 };
 
 } // namespace ck
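The net effect of this first hunk group is an ownership change: the blockwise GEMM now keeps the position of its A slice window as private state (c_thread_begin_mtx_idx_ and a_thread_copy_ move below a private: label) and exposes MoveASliceWindow, so callers advance the window instead of offsetting the shared-memory pointer themselves. Below is a minimal standalone sketch of that pattern; WindowedGemm, its members, and the use of plain pointers instead of the library's buffers and descriptors are all hypothetical simplifications, not the library's API.

#include <cstddef>

// Hypothetical stand-in for BlockwiseGemm_km_kn_m0m1n0n1_v3: it remembers where its
// A window currently is, so the caller never computes "p_a_block + offset" itself.
struct WindowedGemm
{
    // consume the A data under the current window
    // (stand-in for blockwise_gemm.Run(a_block_buf, b_buf, c_buf))
    void Run(const float* b, float* c, std::size_t n) const
    {
        for(std::size_t i = 0; i < n; ++i)
            c[i] += a_window_[i] * b[i];
    }

    // stand-in for MoveASliceWindow: slide the window forward by one E-block
    void MoveASliceWindow(std::size_t step) { a_window_ += step; }

    const float* a_window_; // window position kept inside the GEMM (private in the real struct)
};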
...
@@ -145,8 +145,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                 Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
 
-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+        auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
                                             FloatAB,
                                             FloatAB,
                                             FloatAcc,
@@ -226,6 +225,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         FloatAB* p_a_block = p_shared_block;
 
+        auto a_block_buf = make_dynamic_buffer(p_a_block);
+
         // register allocation for output
         StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
@@ -268,10 +269,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         __syncthreads();
 
-        index_t b_block_data_begin = 0;
-
         if constexpr(HasMainKBlockLoop)
         {
+            index_t e_block_data_begin = 0;
+
             // LDS double buffer: main body
             // use Do-While loop instead of For loop to simplify control flow
             do
@@ -289,13 +290,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 // LDS double buffer: GEMM on current data
                 // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
-                blockwise_gemm.Run(
-                    make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                            make_tuple(b_block_data_begin, 0))),
-                    b_thread_even_buf,
-                    c_thread_buf);
+                blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
 
-                b_block_data_begin += EPerBlock;
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
 
                 b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
                                                          b_thread_slice_copy_step);
@@ -308,15 +305,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                     b_e_n_ho_wo_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(
-                    make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                            make_tuple(b_block_data_begin, 0))),
-                    b_thread_odd_buf,
-                    c_thread_buf);
+                blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
 
-                b_block_data_begin += EPerBlock;
-            } while(b_block_data_begin < E - 2 * EPerBlock);
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
+
+                e_block_data_begin += 2 * EPerBlock;
+            } while(e_block_data_begin < E - 2 * EPerBlock);
         }
 
         // LDS double buffer: tail
@@ -333,29 +328,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                 b_e_n_ho_wo_global_iterator_hacks);
 
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(
-                make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                        make_tuple(b_block_data_begin, 0))),
-                b_thread_even_buf,
-                c_thread_buf);
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
 
-            b_block_data_begin += EPerBlock;
+            blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(
-                make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                        make_tuple(b_block_data_begin, 0))),
-                b_thread_odd_buf,
-                c_thread_buf);
+            blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(
-                make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                        make_tuple(b_block_data_begin, 0))),
-                b_thread_even_buf,
-                c_thread_buf);
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
         }
 
         // output: register to global memory
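Taken together, the gridwise hunks replace per-iteration pointer arithmetic (b_block_data_begin plus CalculateOffset on every Run) with one a_block_buf built before the loop and a window move after each GEMM; the new e_block_data_begin counter only tracks how much of the E dimension has been consumed. A hedged sketch of the resulting control flow, reusing the hypothetical WindowedGemm from the sketch above (the flat-step simplification and all names are assumptions, not the library's API):

// Double-buffered main-loop skeleton: run on the even buffer, advance the A window,
// run on the prefetched odd buffer, advance again; E advances by 2 * EPerBlock per trip.
// Like the real code, this assumes the main loop is only entered when at least two
// E-blocks remain (HasMainKBlockLoop), with the last one or two blocks handled as a tail.
void gemm_main_loop(WindowedGemm& gemm,
                    const float* b_even, const float* b_odd, float* c,
                    std::size_t n, std::size_t E, std::size_t EPerBlock)
{
    std::size_t e_block_data_begin = 0;
    do
    {
        gemm.Run(b_even, c, n);           // GEMM on current (even) data
        gemm.MoveASliceWindow(EPerBlock); // one E-block consumed

        gemm.Run(b_odd, c, n);            // GEMM on prefetched (odd) data
        gemm.MoveASliceWindow(EPerBlock);

        e_block_data_begin += 2 * EPerBlock;
    } while(e_block_data_begin < E - 2 * EPerBlock);
    // tail: GEMM on the remaining one or two E-blocks, as in the diff above
}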
...
@@ -1338,16 +1338,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 // 1. src:
 //    1. SrcDesc is known at compile-time
 //    2. SrcBuffer is DynamicBuffer
-//    3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
-//       compile-time distance to src_reference_idx
-//    4. use #-iterator
+//    3. src_ref_idx is known at run-time
+//    4. SrcRefToOriginDisplacement is known at compile-time
+//    5. use #-iterator
 // 2. dst:
 //    1. DstDesc is known at compile-time
 //    2. DstBuffer is StaticBuffer
-//    3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
-//    2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
-//       compile-time distance to dst_reference_idx
-//    3. use direct address calculation (lower of coordinate)
+//    3. DstOriginIdx is known at compile-time
+//    4. use direct address calculation
 // 3. vector access on src
 template <
     typename SrcData,
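The rename to DstOriginIdx in the hunks below enforces the contract this comment describes: the destination origin is a compile-time index, so the destination offset folds to a constant and the static (register) destination buffer is addressed directly, with no run-time arithmetic. A toy illustration of that folding, with std::integral_constant standing in for the library's Number<> and a fixed-stride calculate_offset standing in for the descriptor; everything here is a hypothetical simplification:

#include <type_traits>

template <int N>
using Number = std::integral_constant<int, N>;

// toy packed 2D descriptor: offset = i0 * Stride0 + i1, stride known at compile time
template <int Stride0>
constexpr int calculate_offset(int i0, int i1) { return i0 * Stride0 + i1; }

template <typename DstOriginIdx0, typename DstOriginIdx1>
constexpr int dst_offset()
{
    // every input is an integral_constant, so the compiler emits a literal offset
    return calculate_offset<8>(DstOriginIdx0::value, DstOriginIdx1::value);
}

static_assert(dst_offset<Number<2>, Number<3>>() == 19, "offset folded at compile time");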
@@ -1381,14 +1379,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
     }
 
     template <typename SrcRefToOriginDisplacement,
-              typename DstRefToOriginDisplacement,
+              typename DstOriginIdx,
               typename SrcBuffer,
               typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcRefToOriginDisplacement&,
                         const SrcBuffer& src_buf,
                         const DstDesc&,
-                        const DstRefToOriginDisplacement&,
+                        const DstOriginIdx&,
                         DstBuffer& dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
@@ -1402,10 +1400,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
 
-        static_assert(is_known_at_compile_time<
-                          remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
-                          is_known_at_compile_time<
-                              remove_cv_t<remove_reference_t<DstRefToOriginDisplacement>>>::value,
-                      "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
-                      "at compile-time");
+        static_assert(
+            is_known_at_compile_time<
+                remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
+            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
+            "at compile-time");
@@ -1415,7 +1413,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
         constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
-        constexpr auto dst_ref_to_origin_disp_idx = to_multi_index(DstRefToOriginDisplacement{});
+        constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
 
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -1505,14 +1503,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
             // copy data from dst_tmp_vector into dst_buf
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                    to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
-                    i * src_scalar_step_in_vector);
+                    dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
 
                 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
             });
         });
     }
+
+    template <typename SrcSliceMoveStepIdx>
+    __device__ void MoveSrcSliceWindow(const SrcDesc&,
+                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
+    {
+        constexpr auto src_desc = SrcDesc{};
+
+        const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
+            src_desc, to_multi_index(src_slice_move_step_idx));
+
+        move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
+    }
 
     private:
     SrcCoord src_ref_coord_;
 };
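This new MoveSrcSliceWindow is what the blockwise GEMM's MoveASliceWindow from the first file delegates to: the transfer keeps one mutable run-time reference coordinate (src_ref_coord_), and sliding the window is a single coordinate update, while each Run reads at a compile-time displacement from that reference. A minimal conceptual sketch, with a flat offset standing in for the library's multi-dimensional tensor coordinate (hypothetical names throughout):

// Conceptual model of MoveSrcSliceWindow: state is one mutable reference coordinate;
// reads are made relative to it, so moving the window never touches the read sites.
struct SliceTransferSketch
{
    long src_ref_offset_ = 0; // stand-in for SrcCoord src_ref_coord_

    // stand-in for move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, step_iter)
    void MoveSrcSliceWindow(long step_offset) { src_ref_offset_ += step_offset; }

    // stand-in for a Run-time read at (reference + compile-time displacement)
    float Read(const float* src, long compile_time_disp) const
    {
        return src[src_ref_offset_ + compile_time_disp];
    }
};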
...