Commit 7484a103 authored by Chao Liu

update v5r1

parent ba31eb3e
@@ -6,11 +6,6 @@
namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, the index calculation cost is greatly reduced:
// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename FloatA,
typename FloatB,
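For reference, the contraction this blockwise GEMM implements is the plain loop nest below. This is a host-side C++ sketch for illustration only; the function name and raw-pointer layout are assumptions, not part of this file. A is stored K-major over M, matching the transpose(A[K, M]) in the comment above.

// Reference semantics of C[M, N] += transpose(A[K, M]) * B[K, N].
// Host-side sketch, not the distributed device implementation:
// in the real kernel each thread computes only its tile of C.
inline void reference_gemm_km_kn_mn(const float* a, // A[K, M], K-major
                                    const float* b, // B[K, N], K-major
                                    float* c,       // C[M, N]
                                    int M, int N, int K)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}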
@@ -58,10 +53,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
AddressSpace::Vgpr,
1>;
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
@@ -183,6 +174,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
});
});
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
}
private:
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
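The new MoveASliceWindow member replaces the caller-side pattern of rebuilding an A buffer at a manually computed offset (visible in the gridwise changes below) with an internal window move: it simply forwards the step to a_thread_copy_.MoveSrcSliceWindow. A self-contained C++ analogue of the pattern, with illustrative names not taken from the library:

#include <cstddef>

// Minimal analogue of a "slice window": keep a cursor into the A block
// and advance it by a step, instead of recomputing "base + offset" from
// scratch on every loop iteration (what the gridwise caller used to do).
struct SliceWindow
{
    const float* base = nullptr; // start of the block buffer
    std::size_t offset = 0;      // current window origin

    const float* Window() const { return base + offset; }

    void Move(std::size_t step) { offset += step; } // cf. MoveSrcSliceWindow
};

In the gridwise kernel below, blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)) plays this role, except that the step is applied to a cached tensor coordinate through the descriptor rather than to a flat offset.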
@@ -145,20 +145,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
@@ -226,6 +225,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
FloatAB* p_a_block = p_shared_block;
auto a_block_buf = make_dynamic_buffer(p_a_block);
// register allocation for output
StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
@@ -268,10 +269,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
__syncthreads();
index_t b_block_data_begin = 0;
if constexpr(HasMainKBlockLoop)
{
index_t e_block_data_begin = 0;
// LDS double buffer: main body
// use a do-while loop instead of a for loop to simplify control flow
do
@@ -289,13 +290,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
b_block_data_begin += EPerBlock;
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
@@ -308,15 +305,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_odd_buf,
c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_block_data_begin += EPerBlock;
e_block_data_begin += 2 * EPerBlock;
} while(b_block_data_begin < E - 2 * EPerBlock);
} while(e_block_data_begin < E - 2 * EPerBlock);
}
// LDS double buffer: tail
@@ -333,29 +328,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
b_e_n_ho_wo_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
b_block_data_begin += EPerBlock;
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_odd_buf,
c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // only 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
make_dynamic_buffer(p_a_block + a_e_k_block_desc.CalculateOffset(
make_tuple(b_block_data_begin, 0))),
b_thread_even_buf,
c_thread_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
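The restructured loop above is the classic LDS double-buffer schedule: compute on the even buffer while prefetching into the odd one, swap, and finish with a two- or one-slice tail so the do-while body always has a full next iteration. A self-contained C++ sketch of that schedule under illustrative assumptions (plain arrays instead of LDS, a fixed buffer size, no __syncthreads, and a tail predicate derived from the slice count; load/compute stand in for the b threadwise transfer and blockwise_gemm.Run + MoveASliceWindow):

// Double-buffer schedule over E in steps of EPerBlock, mirroring the
// main-body + tail structure above.
template <typename Load, typename Compute>
void double_buffered_loop(int E, int EPerBlock, Load load, Compute compute)
{
    float even[256]; // stand-in for the "even" buffer
    float odd[256];  // stand-in for the "odd" buffer

    load(even, 0); // prologue: slice [0, EPerBlock) is loaded up front

    int e = 0;
    if(E > 2 * EPerBlock) // HasMainKBlockLoop: main body exists
    {
        do
        {
            load(odd, e + EPerBlock);      // prefetch next slice
            compute(even);                 // GEMM on current data
            load(even, e + 2 * EPerBlock); // prefetch the slice after
            compute(odd);
            e += 2 * EPerBlock;
        } while(e < E - 2 * EPerBlock);
    }

    if((E / EPerBlock) % 2 == 0) // 2 slices left: double tail
    {
        load(odd, e + EPerBlock);
        compute(even); // GEMM on 2nd-last data
        compute(odd);  // GEMM on last data
    }
    else // 1 slice left
    {
        compute(even); // GEMM on last data
    }
}

With MoveASliceWindow now advancing the A window internally, the manual b_block_data_begin offset arithmetic disappears from the caller, and the new e_block_data_begin survives purely as the loop counter, advancing by 2 * EPerBlock per pass to match the two Run calls.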
// output: register to global memory
@@ -1338,16 +1338,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
// compile-time distance to src_reference_idx
// 4. use #-iterator
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
// 2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
// compile-time distance to dst_reference_idx
// 3. use direct address calculation (lower of coordinate)
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
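The src side of the comment above combines a run-time reference (src_ref_idx) with a compile-time displacement (SrcRefToOriginDisplacement), so the effective address is one run-time base plus a constant the compiler folds away; the dst side is fully compile-time, allowing direct offsets into the StaticBuffer. A minimal self-contained C++ sketch of the src-side idea, with illustrative names:

#include <cstddef>

// Compile-time displacement Disp + run-time reference offset ref:
// the compiler folds Disp into the addressing, so only "ref" costs
// run-time index arithmetic (cf. src_ref_to_origin_disp_idx below).
template <std::size_t Disp>
float load_at(const float* buf, std::size_t ref)
{
    return buf[ref + Disp];
}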
template <
typename SrcData,
@@ -1381,14 +1379,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
}
template <typename SrcRefToOriginDisplacement,
typename DstRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstRefToOriginDisplacement&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
@@ -1402,12 +1400,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<
remove_cv_t<remove_reference_t<DstRefToOriginDisplacement>>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
static_assert(
is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
@@ -1415,7 +1413,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
// SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_ref_to_origin_disp_idx = to_multi_index(DstRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -1505,14 +1503,25 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
to_multi_index(dst_ref_to_origin_disp_idx) + data_to_origin_disp_idx +
i * src_scalar_step_in_vector);
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
});
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx)
{
constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
src_desc, to_multi_index(src_slice_move_step_idx));
move_dynamic_tensor_coordinate(src_desc, src_ref_coord_, src_slice_move_step_iter);
}
private:
SrcCoord src_ref_coord_;
};
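Both new window-moving methods persist a moved coordinate (src_ref_coord_ here) so that later Run calls read from the shifted window without any caller-side offset bookkeeping. A self-contained analogue of that state, assuming a plain 2-D row-major view rather than the library's transform-aware coordinates:

#include <array>
#include <cstddef>

// Persisted coordinate moved by a multi-index step. The real code routes
// the step through move_dynamic_tensor_coordinate so that descriptor
// transforms (e.g. merge, pad, embed) are honored, not plain addition.
struct Coord2D
{
    std::array<std::size_t, 2> idx{};
    std::array<std::size_t, 2> stride{}; // row-major strides of the view

    std::size_t Offset() const { return idx[0] * stride[0] + idx[1] * stride[1]; }

    void Move(std::size_t step0, std::size_t step1)
    {
        idx[0] += step0;
        idx[1] += step1;
    }
};

Under this picture, the gridwise kernel's blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)) amounts to a Move(EPerBlock, 0) on the thread's cached A coordinate.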