Commit f39f7d79 authored by Jing Zhang

skip b_lds

parent 720280ea
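In short: this commit removes the LDS staging path for the B operand of the DL GEMM. The B blockwise copy (global -> LDS) and the B LDS double buffer are deleted; instead each thread reads its own K0PerBlock x NPerThread x K1 slice of B straight from global memory into even/odd register (VGPR) buffers via ThreadwiseTensorSliceTransfer_v2, and BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 now consumes that per-thread buffer directly. The A operand keeps its LDS double buffer, so the shared-memory footprint drops from two double-buffered tiles (A and B) to one (A only). The diff below is reconstructed in unified-diff form ("-" removed, "+" added).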
@@ -14,9 +14,9 @@ template <index_t BlockSize,
           typename FloatB,
           typename FloatC,
           typename ABlockDesc_K0_M_K1,
-          typename BBlockDesc_K0_N_K1,
+          typename BThreadDesc_K0_N_K1,
           index_t MPerThread,
-          index_t NPerThread,
+          index_t NPerBlock,
           index_t K0PerLoop>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 {
@@ -32,10 +32,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static constexpr auto M = ABlockDesc_K0_M_K1{}.GetLength(I1);
     static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
 
-    static constexpr auto N = BBlockDesc_K0_N_K1{}.GetLength(I1);
+    static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);
 
     static constexpr auto M0 = M / MPerThread;
     static constexpr auto M1 = MPerThread;
 
+    static constexpr auto N = NPerBlock;
+
     static constexpr auto N0 = N / NPerThread;
     static constexpr auto N1 = NPerThread;
@@ -51,15 +53,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
         : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
               get_thread_local_1d_id())},
-          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)},
-          b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I2] * NPerThread, 0)}
+          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
     {
         static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
-                          BBlockDesc_K0_N_K1::IsKnownAtCompileTime(),
+                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
 
-        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BBlockDesc_K0_N_K1{}.GetLength(I0) &&
-                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BBlockDesc_K0_N_K1{}.GetLength(I2),
+        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
+                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
                       "wrong! E dimension not consistent\n");
 
         static_assert(K0 % K0PerLoop == 0, "");
@@ -90,27 +91,23 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
         return c_m0_m1_n0_n1_thread_cluster_idx;
     }
 
-    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
-                        const BBlockBuffer& b_block_buf,
+                        const BThreadBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
         static_assert(
             is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
-                is_same<remove_cvref_t<typename BBlockBuffer::type>, remove_cvref_t<FloatB>>::value &&
+                is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
                 is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
             "wrong! inconsistent type");
 
         constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
-        constexpr auto b_block_mtx = BBlockDesc_K0_N_K1{};
 
         // thread A buffer for GEMM
         StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
             a_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatB, b_thread_mtx_.GetElementSpaceSize(), true>
-            b_thread_buf;
-
         constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                                                          FloatB,
                                                                          FloatC,
@@ -126,17 +123,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                                make_tuple(I0, I0, I0),
                                a_thread_buf);
 
-            b_thread_copy_.Run(b_block_mtx,
-                               make_tuple(k0_begin, I0, I0),
-                               b_block_buf,
-                               b_thread_mtx_,
-                               make_tuple(I0, I0, I0),
-                               b_thread_buf);
-
             threadwise_gemm.Run(a_thread_buf,
                                 make_tuple(I0, I0, I0),
                                 b_thread_buf,
-                                make_tuple(I0, I0, I0),
+                                make_tuple(k0_begin, I0, I0),
                                 c_thread_buf,
                                 make_tuple(I0, I0, I0, I0));
         });
@@ -153,20 +143,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                                                       K1,
                                                       K1>;
 
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
-                                                         BBlockDesc_K0_N_K1,
-                                                         decltype(b_thread_mtx_),
-                                                         Sequence<K0PerLoop, NPerThread, K1>,
-                                                         Sequence<0, 1, 2>,
-                                                         2,
-                                                         K1,
-                                                         K1>;
-
     CIndex c_thread_origin_data_idx_;
 
     AThreadCopy a_thread_copy_;
-    BThreadCopy b_thread_copy_;
 };
 
 } // namespace ck
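Net effect on the blockwise GEMM interface above: the struct is now parameterized by the per-thread B descriptor (BThreadDesc_K0_N_K1) together with NPerBlock, derives NPerThread from that descriptor, and Run() takes a caller-provided register buffer for B, indexing it at k0_begin directly instead of first copying a B slice out of LDS on every K0PerLoop step.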
@@ -120,20 +120,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
             math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
-        constexpr auto b_block_aligned_space_size =
-            math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
-
-        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
     }
 
     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
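To see what the halved return value means in bytes, here is a standalone back-of-envelope check; the tile sizes below are invented example values, not taken from this commit, and alignment padding is ignored:

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // assumed example tile shapes: K0PerBlock x {M,N}PerBlock x K1 elements per tile
        constexpr std::size_t K0PerBlock = 8, MPerBlock = 128, NPerBlock = 128, K1 = 4;
        constexpr std::size_t a_size = K0PerBlock * MPerBlock * K1; // A tile elements
        constexpr std::size_t b_size = K0PerBlock * NPerBlock * K1; // B tile elements

        // before: double-buffered A and B tiles in LDS; after: double-buffered A only
        std::printf("LDS before: %zu bytes\n", 2 * (a_size + b_size) * sizeof(float));
        std::printf("LDS after:  %zu bytes\n", 2 * a_size * sizeof(float));
    }

With these example numbers, LDS use per workgroup falls from 65536 to 32768 bytes.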
@@ -397,6 +389,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
+        ignore = b_global_buf;
 
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
@@ -425,26 +418,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, for blockwise GEMM
-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
         static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                           a_k0_m_k1_block_desc.GetElementSpaceSize() &&
-                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
-                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                       "wrong!");
 
         // A matrix blockwise copy
@@ -471,45 +451,36 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 a_block_desc_k0_m0_m1_k1,
                 make_multi_index(0, 0, 0, 0));
 
-        // B matrix blockwise copy
-        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
-            BlockSize,
-            InMemoryDataOperationEnum::Set,
-            Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
-            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
-            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
-            BBlockTransferThreadClusterArrangeOrder,
-            FloatAB,
-            FloatAB,
-            remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
-            decltype(b_block_desc_k0_n0_n1_k1),
-            BBlockTransferSrcAccessOrder,
-            Sequence<0, 1, 2, 3>,
-            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
-            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
-            BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
-            Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
-            false,
-            true>(b_grid_desc_k0_n0_n1_k1,
-                  make_multi_index(0, in0, 0, 0),
-                  b_block_desc_k0_n0_n1_k1,
-                  make_multi_index(0, 0, 0, 0));
+        static constexpr auto b_thread_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
+
+        auto b_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<FloatAB,
+                                             FloatAB,
+                                             remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
+                                             decltype(b_thread_desc_k0_n0_n1_k1),
+                                             Sequence<K0PerBlock, 1, NPerThread, K1.value>,
+                                             Sequence<0, 1, 2, 3>, // BBlockTransferSrcAccessOrder
+                                             3,
+                                             K1,
+                                             1,
+                                             false,
+                                             true>(
+                b_grid_desc_k0_n0_n1_k1,
+                make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0));
+
+        static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
 
         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
         //     a_mtx[K0PerBlock, MPerBlock] is in LDS
         //     b_mtx[K0PerBlock, NPerBlock] is in LDS
         //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //       register
         const auto blockwise_gemm =
             BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                  FloatAB,
                                                  FloatAB,
                                                  FloatAcc,
                                                  decltype(a_k0_m_k1_block_desc),
-                                                 decltype(b_k0_n_k1_block_desc),
+                                                 decltype(b_k0_n_k1_thread_desc),
                                                  MPerThread,
-                                                 NPerThread,
+                                                 NPerBlock,
                                                  KPerThread>{};
 
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
@@ -522,11 +493,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
-        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
 
         FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+
+        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());
+        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());
 
         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
@@ -535,28 +508,26 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         // Initialize C
         c_thread_buf.Clear();
 
-        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
 
         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double + b_block_aligned_space_size,
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
-
             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
+
+            b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                  b_global_buf,
+                                  b_thread_desc_k0_n0_n1_k1,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_even_buf);
         }
 
         if constexpr(HasMainKBlockLoop)
@@ -572,40 +543,50 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 // even iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);
 
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_odd_buf);
 
                 block_sync_lds();
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);
 
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_even_buf);
 
                 block_sync_lds();
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
 
                 k_block_data_begin += 2 * K0PerBlock;
             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
@@ -615,32 +596,37 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        if constexpr(HasDoubleTailKBlockLoop) // if has 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
-           b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
+           b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_thread_slice_copy_step);
 
            block_sync_lds();
 
            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-           b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+           b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                 b_global_buf,
+                                 b_thread_desc_k0_n0_n1_k1,
+                                 make_tuple(I0, I0, I0, I0),
+                                 b_thread_odd_buf);
 
            // LDS double buffer: GEMM on 2nd-last data
-           blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+           blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
 
            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-           b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
            block_sync_lds();
 
            // LDS double buffer: GEMM on last data
-           blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+           blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
        }
        else // if has 1 iteration left
        {
            __syncthreads();
 
            // LDS double buffer: GEMM on last data
-           blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+           blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
        }
 
        // output: register to global memory
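The restructured ping-pong schedule is easier to follow with the tensor machinery stripped away. Below is a minimal host-side C++ model of the loop above; every name and size is invented for illustration, the copy functions stand in for a_blockwise_copy / b_threadwise_copy, and the real kernel additionally separates reads from LDS writes and synchronizes with block_sync_lds():

    #include <cstdio>

    constexpr int K0PerBlock = 4;
    constexpr int K0         = 16; // assume K0 % (2 * K0PerBlock) == 0, as in the 2-tail path

    // stand-ins: A goes global -> LDS, B goes global -> per-thread registers
    void copy_a_to_lds(int k0, int buf)  { std::printf("A: global -> lds[%d]  (k0=%d)\n", buf, k0); }
    void copy_b_to_regs(int k0, int buf) { std::printf("B: global -> regs[%d] (k0=%d)\n", buf, k0); }
    void blockwise_gemm(int a_buf, int b_buf) { std::printf("gemm(lds[%d], regs[%d])\n", a_buf, b_buf); }

    int main()
    {
        // preload the first K0PerBlock tile into the "even" buffers
        copy_a_to_lds(0, 0);
        copy_b_to_regs(0, 0);

        // main loop: compute on one buffer pair while prefetching into the other
        int k0 = 0;
        while(k0 < K0 - 2 * K0PerBlock)
        {
            copy_a_to_lds(k0 + K0PerBlock, 1);  // prefetch next A tile into odd LDS buffer
            copy_b_to_regs(k0 + K0PerBlock, 1); // prefetch next B tile into odd registers
            blockwise_gemm(0, 0);               // GEMM on even buffers

            copy_a_to_lds(k0 + 2 * K0PerBlock, 0);
            copy_b_to_regs(k0 + 2 * K0PerBlock, 0);
            blockwise_gemm(1, 1);               // GEMM on odd buffers

            k0 += 2 * K0PerBlock;
        }

        // double tail: two tiles remain (k0 and k0 + K0PerBlock)
        copy_a_to_lds(k0 + K0PerBlock, 1);
        copy_b_to_regs(k0 + K0PerBlock, 1);
        blockwise_gemm(0, 0); // 2nd-last tile
        blockwise_gemm(1, 1); // last tile
    }

The only structural change from the pre-commit pipeline is the destination of the B prefetch: it lands in registers rather than in a second LDS double buffer, which is why the LDS-size function and all b_block_* buffers disappear in this diff.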