Commit ecef4987 authored by Po-Yen, Chen
Browse files

Extract reading A tile logic out of Run() method

parent 3943aab3
......@@ -306,17 +306,12 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
#if defined(EXTRACT_DS_READ)
static_assert(MRepeat == 1);
#endif // defined(EXTRACT_DS_READ)
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
#if defined(EXTRACT_DS_READ)
static_assert(MRepeat == 1);
template <typename ABlockBuffer>
__device__ void PrepareRun(const ABlockBuffer& a_block_buf) const {
Number<0> m0;
// read A
......@@ -326,6 +321,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
}
#endif // defined(EXTRACT_DS_READ)
#if defined(EXTRACT_DS_READ)
template <typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
Number<0> m0;
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
......@@ -360,8 +366,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
});
});
#else
static_assert(false);
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
......
......@@ -84,7 +84,8 @@ struct GridwiseGemmPipeline_v2
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.PrepareRun(a_block_buf);
blockwise_gemm.Run(b_block_buf, c_thread_buf);
block_sync_lds();
......@@ -111,7 +112,8 @@ struct GridwiseGemmPipeline_v2
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.PrepareRun(a_block_buf);
blockwise_gemm.Run(b_block_buf, c_thread_buf);
block_sync_lds();
......@@ -122,7 +124,8 @@ struct GridwiseGemmPipeline_v2
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
blockwise_gemm.PrepareRun(a_block_buf);
blockwise_gemm.Run(b_block_buf, c_thread_buf);
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment