Commit f39f7d79 authored by Jing Zhang

skip b_lds

parent 720280ea
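In short: the B matrix no longer goes through LDS at all ("skip b_lds"). Each thread now pulls its own K0PerBlock x 1 x NPerThread x K1 window of B from global memory straight into VGPRs with a ThreadwiseTensorSliceTransfer_v2, double-buffered in registers, while A keeps its LDS double buffer. Schematically:

    A (unchanged): global --a_blockwise_copy--> LDS --a_thread_copy_--> VGPR --> threadwise GEMM
    B before:      global --b_blockwise_copy--> LDS --b_thread_copy_--> VGPR --> threadwise GEMM
    B after:       global ---------b_threadwise_copy-----------------> VGPR --> threadwise GEMM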
@@ -14,9 +14,9 @@ template <index_t BlockSize,
           typename FloatB,
           typename FloatC,
           typename ABlockDesc_K0_M_K1,
-          typename BBlockDesc_K0_N_K1,
+          typename BThreadDesc_K0_N_K1,
           index_t MPerThread,
-          index_t NPerThread,
+          index_t NPerBlock,
           index_t K0PerLoop>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 {
@@ -32,10 +32,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static constexpr auto M  = ABlockDesc_K0_M_K1{}.GetLength(I1);
     static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
-    static constexpr auto N  = BBlockDesc_K0_N_K1{}.GetLength(I1);
+    static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);

     static constexpr auto M0 = M / MPerThread;
     static constexpr auto M1 = MPerThread;

+    static constexpr auto N = NPerBlock;
+
     static constexpr auto N0 = N / NPerThread;
     static constexpr auto N1 = NPerThread;
@@ -51,15 +53,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
         : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
               get_thread_local_1d_id())},
-          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)},
-          b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I2] * NPerThread, 0)}
+          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
     {
         static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
-                          BBlockDesc_K0_N_K1::IsKnownAtCompileTime(),
+                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BBlockDesc_K0_N_K1{}.GetLength(I0) &&
-                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BBlockDesc_K0_N_K1{}.GetLength(I2),
+        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
+                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
                       "wrong! E dimension not consistent\n");

         static_assert(K0 % K0PerLoop == 0, "");
@@ -90,27 +91,23 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
         return c_m0_m1_n0_n1_thread_cluster_idx;
     }

-    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
-                        const BBlockBuffer& b_block_buf,
+                        const BThreadBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
         static_assert(
             is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
-                is_same<remove_cvref_t<typename BBlockBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
                 is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
             "wrong! inconsistent type");

         constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
-        constexpr auto b_block_mtx = BBlockDesc_K0_N_K1{};

         // thread A buffer for GEMM
         StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
             a_thread_buf;

-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatB, b_thread_mtx_.GetElementSpaceSize(), true>
-            b_thread_buf;
-
         constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                                                          FloatB,
                                                                          FloatC,
@@ -126,17 +123,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

-            b_thread_copy_.Run(b_block_mtx,
-                               make_tuple(k0_begin, I0, I0),
-                               b_block_buf,
-                               b_thread_mtx_,
-                               make_tuple(I0, I0, I0),
-                               b_thread_buf);
-
             threadwise_gemm.Run(a_thread_buf,
                                 make_tuple(I0, I0, I0),
                                 b_thread_buf,
-                                make_tuple(I0, I0, I0),
+                                make_tuple(k0_begin, I0, I0),
                                 c_thread_buf,
                                 make_tuple(I0, I0, I0, I0));
         });
...@@ -153,20 +143,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 ...@@ -153,20 +143,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
K1, K1,
K1>; K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc_K0_N_K1,
decltype(b_thread_mtx_),
Sequence<K0PerLoop, NPerThread, K1>,
Sequence<0, 1, 2>,
2,
K1,
K1>;
CIndex c_thread_origin_data_idx_; CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
}; };
} // namespace ck } // namespace ck
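The rewritten Run() is the heart of the change: the per-loop B copy out of LDS disappears, because the thread's entire K0-range slice of B already sits in registers, so the threadwise GEMM simply offsets into it at k0_begin. A host-side sketch of that control flow, under assumed toy sizes and a packed [k0][m][k1] / [k0][n][k1] layout (illustrative stand-ins, not the kernel's real descriptors):

    // Mirrors BlockwiseGemmDlops_km_kn_m0m1n0n1_v3::Run after this commit:
    // A is still sliced from the block (LDS) buffer once per K0PerLoop step,
    // while the B register buffer covers the whole K0 range and is indexed
    // at k0_begin instead of being re-copied.
    #include <array>

    constexpr int K0 = 8, K0PerLoop = 2, M1 = 4, N1 = 4, K1 = 2;

    void run_gemm(const std::array<float, K0 * M1 * K1>& a_block_buf,
                  const std::array<float, K0 * N1 * K1>& b_thread_buf,
                  std::array<float, M1 * N1>& c_thread_buf)
    {
        for(int k0_begin = 0; k0_begin < K0; k0_begin += K0PerLoop)
        {
            // a_thread_copy_: copy slice [k0_begin, k0_begin + K0PerLoop) to registers
            std::array<float, K0PerLoop * M1 * K1> a_thread_buf{};
            for(int i = 0; i < K0PerLoop * M1 * K1; ++i)
                a_thread_buf[i] = a_block_buf[k0_begin * M1 * K1 + i];

            // threadwise_gemm.Run(a, {0,0,0}, b, {k0_begin,0,0}, c, {0,0,0,0})
            for(int k0 = 0; k0 < K0PerLoop; ++k0)
                for(int m = 0; m < M1; ++m)
                    for(int n = 0; n < N1; ++n)
                        for(int k1 = 0; k1 < K1; ++k1)
                            c_thread_buf[m * N1 + n] +=
                                a_thread_buf[(k0 * M1 + m) * K1 + k1] *
                                b_thread_buf[((k0_begin + k0) * N1 + n) * K1 + k1];
        }
    }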
......
@@ -120,20 +120,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
             math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_aligned_space_size =
-            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
-
-        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
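The effect on shared memory is direct: the double-buffered B allocation is gone, so the LDS footprint halves whenever the A and B tiles were the same size. A quick host-side check, assuming an illustrative tile K0PerBlock=8, MPerBlock=NPerBlock=128, K1=2 with fp32 (these are assumed tuning values, not taken from this diff), and ignoring the integer_least_multiple alignment padding:

    #include <cstdio>

    constexpr long a_elems = 8L * 128 * 2; // K0PerBlock * MPerBlock * K1
    constexpr long b_elems = 8L * 128 * 2; // K0PerBlock * NPerBlock * K1

    constexpr long lds_before = 2 * (a_elems + b_elems) * sizeof(float); // A+B, double-buffered
    constexpr long lds_after  = 2 * a_elems * sizeof(float);             // A only

    int main() { std::printf("%ld -> %ld bytes\n", lds_before, lds_after); } // 32768 -> 16384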
...@@ -397,6 +389,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -397,6 +389,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
ignore = b_global_buf;
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
@@ -425,26 +418,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

-        // TODO: check alignment
-        // B matrix in LDS memory, for blockwise GEMM
-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
         static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                               a_k0_m_k1_block_desc.GetElementSpaceSize() &&
-                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
-                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                       "wrong!");

         // A matrix blockwise copy
...@@ -471,45 +451,36 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -471,45 +451,36 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
a_block_desc_k0_m0_m1_k1, a_block_desc_k0_m0_m1_k1,
make_multi_index(0, 0, 0, 0)); make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy static constexpr auto b_thread_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< make_tuple(Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
BlockSize,
InMemoryDataOperationEnum::Set, auto b_threadwise_copy =
Sequence<K0PerBlock, 1, NPerBlock, K1.value>, ThreadwiseTensorSliceTransfer_v2<FloatAB,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1, FloatAB,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1, remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
BBlockTransferThreadClusterArrangeOrder, decltype(b_thread_desc_k0_n0_n1_k1),
FloatAB, Sequence<K0PerBlock, 1, NPerThread, K1.value>,
FloatAB, Sequence<0, 1, 2, 3>, // BBlockTransferSrcAccessOrder,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>, 3,
decltype(b_block_desc_k0_n0_n1_k1), K1,
BBlockTransferSrcAccessOrder, 1,
Sequence<0, 1, 2, 3>, false,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths true>(
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths b_grid_desc_k0_n0_n1_k1,
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0));
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false, static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
true>(b_grid_desc_k0_n0_n1_k1, make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
make_multi_index(0, in0, 0, 0),
b_block_desc_k0_n0_n1_k1,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
const auto blockwise_gemm = const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
decltype(a_k0_m_k1_block_desc), decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc), decltype(b_k0_n_k1_thread_desc),
MPerThread, MPerThread,
NPerThread, NPerBlock,
KPerThread>{}; KPerThread>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
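Note the new source origin, make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0): each thread claims a disjoint NPerThread-wide strip of the block's N range. What B saves in LDS it now costs in registers, two K0PerBlock x NPerThread x K1 buffers per thread (even/odd). With the same illustrative numbers as above plus an assumed NPerThread=4:

    // Per-thread VGPR cost of the register-resident B path (assumed tuning
    // values, not taken from this diff).
    constexpr int K0PerBlock = 8, NPerThread = 4, K1 = 2;

    constexpr int b_elems_per_buf = K0PerBlock * NPerThread * K1;        // 64 floats per buffer
    constexpr int b_bytes_total   = 2 * b_elems_per_buf * sizeof(float); // even + odd buffers
    static_assert(b_bytes_total == 512, "two 256-byte register buffers per thread");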
@@ -522,11 +493,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

-        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
-
         FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+
+        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());
+        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());

         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
@@ -535,28 +508,26 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         // Initialize C
         c_thread_buf.Clear();

         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double + b_block_aligned_space_size,
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
-
             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
+
+            b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                  b_global_buf,
+                                  b_thread_desc_k0_n0_n1_k1,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_even_buf);
         }

         if constexpr(HasMainKBlockLoop)
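The main loop below keeps the same even/odd software pipeline, but B's RunRead/RunWrite pair collapses into a single Run into the alternate register buffer, and block_sync_lds() now only protects the A tiles. A minimal host-side mock of the revised schedule (print statements stand in for the actual transfers; buffer names mirror the kernel):

    #include <cstdio>

    int main()
    {
        const int K0 = 32, K0PerBlock = 8; // assumed sizes, for illustration
        int k_block_data_begin = 0;

        auto load_a_global = [] { std::printf("A: global -> vgpr staging\n"); };
        auto store_a_lds   = [](const char* b) { std::printf("A: staging -> %s\n", b); };
        auto load_b        = [](const char* b) { std::printf("B: global -> %s (skips LDS)\n", b); };
        auto gemm          = [](const char* a, const char* b) { std::printf("gemm(%s, %s)\n", a, b); };

        load_a_global(); store_a_lds("a_lds_even"); // kernel prologue: preload
        load_b("b_vgpr_even");

        do
        {
            // even iteration: prefetch next tiles, compute on the "even" buffers
            load_a_global(); load_b("b_vgpr_odd");
            /* block_sync_lds() guards A only */ gemm("a_lds_even", "b_vgpr_even");
            store_a_lds("a_lds_odd");

            // odd iteration: same with roles swapped
            load_a_global(); load_b("b_vgpr_even");
            gemm("a_lds_odd", "b_vgpr_odd");
            store_a_lds("a_lds_even");

            k_block_data_begin += 2 * K0PerBlock;
        } while(k_block_data_begin < K0 - 2 * K0PerBlock);
    }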
@@ -572,40 +543,50 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 // even iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);

                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_odd_buf);

                 block_sync_lds();

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);

                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_even_buf);

                 block_sync_lds();

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                 k_block_data_begin += 2 * K0PerBlock;
             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
@@ -615,32 +596,37 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             if constexpr(HasDoubleTailKBlockLoop) // if has 2 iterations left
             {
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
+
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_thread_slice_copy_step);

                 block_sync_lds();

                 // LDS double buffer: load last data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_odd_buf);

                 // LDS double buffer: GEMM on 2nd-last data
-                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);

                 // LDS double buffer: store last data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                 block_sync_lds();

                 // LDS double buffer: GEMM on last data
-                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
             }
             else // if has 1 iteration left
             {
                 __syncthreads();

                 // LDS double buffer: GEMM on last data
-                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
             }

             // output: register to global memory
......