Commit e6f5a78b authored by coderfeli
Browse files

add double buffer scratch

parent 3784329b
......@@ -300,21 +300,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
......@@ -351,10 +347,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
......@@ -364,14 +359,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>>();
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
// b_thread_vec.template AsType<ComputeDataType>()(ik) =
// b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
......@@ -400,20 +392,65 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
// make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
// b_block_buf,
// b_thread_desc_,
// make_tuple(n0, I0, k0, I0),
// b_thread_buf);
// });
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{});
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 1));
}
// tail
......@@ -424,14 +461,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>>();
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
// b_thread_vec.template AsType<ComputeDataType>()(ik) =
// b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
......
......@@ -109,10 +109,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1
}
}
template <typename SeqIdx>
__device__ auto GetSrcThreadScratchIdx()
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto GetSrcThreadScratchIdx()
{
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx>();
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
......
......@@ -1394,7 +1394,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
2>(
b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid, 0, 0),
b_element_op,
......
......@@ -268,7 +268,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ auto GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
__device__ constexpr auto GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment