Unverified Commit 9d69a099 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

[CK_TILE] Fix compiler related FA bwd issues (#1530)

* add barriers

* tail bias barriers

* adjust bf16/hd256 tol

* continue adjust bf16/hd256 tol
parent 42e6dcea
...@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[]) ...@@ -99,13 +99,26 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataType>
auto get_elimit(int /*init_method*/) auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
{
rtol = 3.2e-2;
atol = 3.2e-2;
}
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType> template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
...@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -899,7 +912,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method); auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result, bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref, dq_host_ref,
std::string("Error: QGrad Incorrect results!"), std::string("Error: QGrad Incorrect results!"),
......
...@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}, },
s_acc, s_acc,
bias_s_tile); bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
...@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>(); HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2 // STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{}; auto dp_acc = SPGradBlockTileType{};
...@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>(); HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D) // STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{}; auto ds = SPGradBlockTileType{};
...@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>()); Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile); shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile); store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
} }
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
...@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4}); move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>(); HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4 // STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{}; auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); clear_tile(dq_acc);
...@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}); });
HotLoopScheduler::template GemmStagedScheduler<4>(); HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale // Results Scale
if constexpr(FmhaDropout::IsDropout) if constexpr(FmhaDropout::IsDropout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment