Commit a5670e67 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/fav2_fwd_sept

parents 75b09986 9d69a099
...@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){ ...@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
def retimage def retimage
(retimage, image) = getDockerImage(conf) (retimage, image) = getDockerImage(conf)
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 48, unit: 'HOURS') timeout(time: 48, unit: 'HOURS')
{ {
...@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){ ...@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try { try {
(retimage, image) = getDockerImage(conf) (retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) { withDockerContainer(image: image, args: dockerOpts) {
...@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){ ...@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try { try {
(retimage, image) = getDockerImage(conf) (retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) { withDockerContainer(image: image, args: dockerOpts) {
...@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){ ...@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
def variant = env.STAGE_NAME def variant = env.STAGE_NAME
def retimage def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try { try {
(retimage, image) = getDockerImage(conf) (retimage, image) = getDockerImage(conf)
} }
...@@ -838,7 +838,7 @@ pipeline { ...@@ -838,7 +838,7 @@ pipeline {
dbsshport = "${dbsshport}" dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}" dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}" dbsshpassword = "${dbsshpassword}"
status_wrapper_creds = "${status_wrapper_creds}" ck_git_creds = "${ck_git_creds}"
gerrit_cred="${gerrit_cred}" gerrit_cred="${gerrit_cred}"
DOCKER_BUILDKIT = "1" DOCKER_BUILDKIT = "1"
} }
......
rocm-docs-core==1.8.1 rocm-docs-core==1.8.2
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.8.1 rocm-docs-core==1.8.2
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[]) ...@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
auto get_elimit(int /*init_method*/) auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
{
rtol = 3.2e-2;
atol = 3.2e-2;
}
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
...@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method); auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result, bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref, dq_host_ref,
std::string("Error: QGrad Incorrect results!"), std::string("Error: QGrad Incorrect results!"),
......
...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<1>() __device__ constexpr auto TailScheduler<1>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<2>() __device__ constexpr auto TailScheduler<2>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
......
...@@ -324,55 +324,55 @@ struct DppSelector ...@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp(); static constexpr auto GetDpp();
template <> template <>
static constexpr auto GetDpp<half_t, 8, 32>() constexpr auto GetDpp<half_t, 8, 32>()
{ {
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 8, 16>() constexpr auto GetDpp<half_t, 8, 16>()
{ {
return DppInstr::dpp8_f16_8x16x2; return DppInstr::dpp8_f16_8x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() constexpr auto GetDpp<half_t, 16, 16>()
{ {
return DppInstr::dpp8_f16_16x16x2; return DppInstr::dpp8_f16_16x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 32, 8>() constexpr auto GetDpp<half_t, 32, 8>()
{ {
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 1, 32>() constexpr auto GetDpp<half_t, 1, 32>()
{ {
return DppInstr::dpp8_f16_1x32x2; return DppInstr::dpp8_f16_1x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 32>() constexpr auto GetDpp<half_t, 2, 32>()
{ {
return DppInstr::dpp8_f16_2x32x2; return DppInstr::dpp8_f16_2x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 16>() constexpr auto GetDpp<half_t, 2, 16>()
{ {
return DppInstr::dpp8_f16_2x16x2; return DppInstr::dpp8_f16_2x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 16>() constexpr auto GetDpp<half_t, 4, 16>()
{ {
return DppInstr::dpp8_f16_4x16x2; return DppInstr::dpp8_f16_4x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 32>() constexpr auto GetDpp<half_t, 4, 32>()
{ {
return DppInstr::dpp8_f16_4x32x2; return DppInstr::dpp8_f16_4x32x2;
} }
......
...@@ -415,7 +415,7 @@ struct WmmaSelector ...@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma(); static constexpr auto GetWmma();
template <> template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>() constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
...@@ -425,7 +425,7 @@ struct WmmaSelector ...@@ -425,7 +425,7 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
...@@ -435,19 +435,19 @@ struct WmmaSelector ...@@ -435,19 +435,19 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>() constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{ {
return WmmaInstr::wmma_f16_16x16x16_f16; return WmmaInstr::wmma_f16_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{ {
return WmmaInstr::wmma_bf16_16x16x16_bf16; return WmmaInstr::wmma_bf16_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>() constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
...@@ -458,7 +458,7 @@ struct WmmaSelector ...@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>() constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu4; return WmmaInstr::wmma_i32_16x16x16_iu4;
} }
......
...@@ -651,97 +651,97 @@ struct MfmaSelector ...@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
static constexpr auto GetMfma<double, 16, 16>() constexpr auto GetMfma<double, 16, 16>()
{ {
return MfmaInstr::mfma_f64_16x16x4f64; return MfmaInstr::mfma_f64_16x16x4f64;
} }
template <> template <>
static constexpr auto GetMfma<float, 64, 64>() constexpr auto GetMfma<float, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 64>() constexpr auto GetMfma<float, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 64>() constexpr auto GetMfma<float, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x1xf32; return MfmaInstr::mfma_f32_16x16x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 8, 64>() constexpr auto GetMfma<float, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 4, 64>() constexpr auto GetMfma<float, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 32>() constexpr auto GetMfma<float, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x2xf32; return MfmaInstr::mfma_f32_32x32x2xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 16>() constexpr auto GetMfma<float, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x4xf32; return MfmaInstr::mfma_f32_16x16x4xf32;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 64, 64>() constexpr auto GetMfma<half_t, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 64>() constexpr auto GetMfma<half_t, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 32>() constexpr auto GetMfma<half_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 16>() constexpr auto GetMfma<half_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x16f16; return MfmaInstr::mfma_f32_16x16x16f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 64>() constexpr auto GetMfma<half_t, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x4f16; return MfmaInstr::mfma_f32_16x16x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 8, 64>() constexpr auto GetMfma<half_t, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 4, 64>() constexpr auto GetMfma<half_t, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>() constexpr auto GetMfma<bhalf_t, 32, 32>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k; return MfmaInstr::mfma_f32_32x32x8bf16_1k;
...@@ -751,7 +751,7 @@ struct MfmaSelector ...@@ -751,7 +751,7 @@ struct MfmaSelector
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>() constexpr auto GetMfma<bhalf_t, 16, 16>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k; return MfmaInstr::mfma_f32_16x16x16bf16_1k;
...@@ -762,72 +762,72 @@ struct MfmaSelector ...@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940) #if defined(CK_USE_AMD_MFMA_GFX940)
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x16i8; return MfmaInstr::mfma_i32_32x32x16i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x32i8; return MfmaInstr::mfma_i32_16x16x32i8;
} }
#else #else
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x16i8; return MfmaInstr::mfma_i32_16x16x16i8;
} }
#endif #endif
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() constexpr auto GetMfma<f8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8f8; return MfmaInstr::mfma_f32_32x32x16f8f8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16>() constexpr auto GetMfma<f8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8bf8; return MfmaInstr::mfma_f32_32x32x16bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16>() constexpr auto GetMfma<bf8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8bf8; return MfmaInstr::mfma_f32_32x32x16f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>() constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8f8; return MfmaInstr::mfma_f32_32x32x16bf8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>() constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
......
...@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}, },
s_acc, s_acc,
bias_s_tile); bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
...@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>(); HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{}; auto dp_acc = SPGradBlockTileType{};
...@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>(); HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{}; auto ds = SPGradBlockTileType{};
...@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile); shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile); store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
...@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4}); move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>(); HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4 // STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{}; auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); clear_tile(dq_acc);
...@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}); });
HotLoopScheduler::template GemmStagedScheduler<4>(); HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale // Results Scale
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
......
...@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
} }
template <> template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>() CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
{ {
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K // Comp: Q x K
...@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
} }
template <> template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>() CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
{ {
// Mem: Q^T LDS load // Mem: Q^T LDS load
// Comp: OGrad x V // Comp: OGrad x V
...@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
} }
template <> template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>() CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
{ {
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad // Comp: PT x OGrad
...@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
} }
template <> template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>() CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
{ {
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load. // Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT // Comp: SGradT x QT
...@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
} }
template <> template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>() CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
{ {
// Mem: SGrad, OGrad, D LDS load. // Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT // Comp: SGrad x KT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment