"docs/source/en/api/pipelines/stable_diffusion/overview.mdx" did not exist on "8eaaa546d89f836b716e92348786d878f883ee86"
Commit 72428037 authored by aska-0096's avatar aska-0096
Browse files

temp save

parent 4e92d44c
...@@ -1757,20 +1757,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1757,20 +1757,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){ if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){
if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
}
else{
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
} }
else if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
} }
}); });
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
}
else{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
} }
else if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
} }
}); });
} }
...@@ -1784,13 +1788,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1784,13 +1788,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MFMA_INST = Gemm4MFMA; constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency // To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA = constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
static_for<0, MFMA_INST, 1>{}([&](auto i) { static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
}
else{
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
}
}
}); });
} }
...@@ -1843,11 +1852,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1843,11 +1852,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t QT_LDS_READ = static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>(); kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 = static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); // kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>(); static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 = static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); // kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
static constexpr index_t OGrad_LDS_READ = static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>(); kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment