"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "a29d2b76757eadf83e0d7c902b0fdbcb2610cf02"
Commit 84a1dfb0 authored by aska-0096

1. Better LDS pipeline.

2. Fix too-early s_barrier().
parent 8e8c6ea1
@@ -39,7 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl
 // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>;
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
 // clang-format on
 using DeviceGemmInstance = DeviceGemmInstance1;
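Note on the hunk above: the benchmark instance grows from a 32x32 block tile on a 64-thread block to a 256x128 tile on a 256-thread block, with MXdlPerWave/NXdlPerWave raised from 1/1 to 4/2 and the copy thread clusters widened from S<4, 16, 1> to S<4, 64, 1>. A quick consistency check of the new shape, as a hedged sketch: it assumes wave64 and the 32x32 XDL instruction named in the template arguments, and the variable names below are ours, not CK's.

#include <cstdio>

int main()
{
    // Numbers taken from the "+" line in the hunk above; wave size is 64 on gfx9.
    const int BlockSize = 256, MPerBlock = 256, NPerBlock = 128;
    const int MPerXdl = 32, NPerXdl = 32, MXdlPerWave = 4, NXdlPerWave = 2;

    const int waves_per_block = BlockSize / 64;              // 4 waves
    const int m_waves = MPerBlock / (MXdlPerWave * MPerXdl); // 256 / 128 = 2
    const int n_waves = NPerBlock / (NXdlPerWave * NPerXdl); // 128 /  64 = 2

    // The 2 x 2 wave grid must cover the block tile with exactly 4 waves.
    std::printf("waves=%d grid=%dx%d\n", waves_per_block, m_waves, n_waves);
    return (m_waves * n_waves == waves_per_block) ? 0 : 1;
}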
@@ -383,12 +383,6 @@ struct BlockwiseGemmXdlops_pipeline_v1
         do
         {
-            __builtin_amdgcn_sched_group_barrier(0x020, 4, 0); // VMEM read
-            __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
-            __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
-            __builtin_amdgcn_sched_group_barrier(0x200, 4, 0); // DS write
-            __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
-            __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
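The lines removed here (and re-added with larger group sizes at the end of the loop body, see the last hunk) are calls to the clang AMDGPU builtin __builtin_amdgcn_sched_group_barrier(mask, size, sync_id): each call pins a group of `size` instructions matching `mask` at that point in the scheduling region, so a sequence of calls lays out a software pipeline. The masks used in this commit are 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write. A minimal compilable sketch of the mechanism, with a hypothetical kernel (hipcc targeting a CDNA GPU assumed), not CK code:

#include <hip/hip_runtime.h>

__global__ void staged_copy(const float* __restrict__ in, float* __restrict__ out)
{
    __shared__ float lds[256];
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    float v = in[i];                       // 1x VMEM read
    lds[threadIdx.x] = v;                  // 1x DS write
    __syncthreads();
    out[i] = lds[(threadIdx.x + 1) % 256]; // 1x DS read, 1x VMEM write

    // Hints, not fences: pin 1 VMEM read, then 1 DS write, then 1 DS read
    // in that order within this scheduling region (sync group 0).
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}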
@@ -449,27 +443,9 @@ struct BlockwiseGemmXdlops_pipeline_v1
             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
             ++i;
-            // Wait all wave produce next k-loop data
-            block_sync_lds();
-            // Here 1 time read(idx=0) of next K-loop & compute(idx=KRepeat) this K-loop
+            // compute(idx=KRepeat) this K-loop can hide ds_write latency
             static_for<0, MRepeat, 1>{}([&](auto m0) {
-                // read A
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                                   make_tuple(m0, I0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, I0),
-                                   a_thread_buf);
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
-                    // read B
-                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                       make_tuple(n0, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(n0, I0, I0, I0),
-                                       b_thread_buf);
                     /* Compute N */
                     vector_type<FloatAB, KPack> a_thread_vec;
                     vector_type<FloatAB, KPack> b_thread_vec;
@@ -495,6 +471,34 @@ struct BlockwiseGemmXdlops_pipeline_v1
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
+            // Wait all wave produce next k-loop data
+            block_sync_lds();
+            // Here 1 time prefetch read(idx=0) of next K-loop
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                // read A
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, I0, I0),
+                                   a_thread_buf);
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    // read B
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(n0, I0, I0, I0),
+                                       b_block_buf,
+                                       b_thread_desc_,
+                                       make_tuple(n0, I0, I0, I0),
+                                       b_thread_buf);
+                });
+            });
+            __builtin_amdgcn_sched_group_barrier(0x020, 6, 0);  // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0);  // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x200, 6, 0);  // DS write
+            __builtin_amdgcn_sched_group_barrier(0x008, 16, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x100, 6, 0);  // DS read
         } while(i < (num_loop - 1));
     }
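Taken together, the two pipeline hunks move block_sync_lds() from before the MFMA section to after it: the LDS-to-register prefetch of the next K-loop now happens after the math, so the ds_writes issued by RunWrite can retire behind the MFMAs instead of stalling at a too-early s_barrier (point 2 of the commit message). A structural sketch of the new ordering, using hypothetical no-op stand-ins rather than the CK API; the point is the order of operations, not the index bookkeeping:

#include <cstdio>

// Hypothetical stand-ins for the CK calls, for illustration only.
void run_global_read() { std::puts("buffer_load (next tile)"); }
void run_lds_write()   { std::puts("ds_write    (next tile, left in flight)"); }
void compute_mfma()    { std::puts("mfma block  (this tile)"); }
void block_sync_lds()  { std::puts("s_barrier"); }
void prefetch_regs()   { std::puts("ds_read     (idx=0 of next K-loop)"); }

int main()
{
    const int num_loop = 4;
    int i = 0;
    do
    {
        run_global_read(); // VMEM reads for a future tile
        run_lds_write();   // DS writes are issued but not waited on
        ++i;
        compute_mfma();    // math on already-prefetched registers hides them
        block_sync_lds();  // the barrier now sits *after* the math (the fix)
        prefetch_regs();   // one LDS->register read of the next K-loop
    } while(i < (num_loop - 1));
    compute_mfma();        // tail iteration
    return 0;
}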