Commit dcc3593f authored by danyao12's avatar danyao12
Browse files

fix hd32 error and boost performance

parent b2510c05
...@@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -448,7 +448,7 @@ class FmhaBwdDQDKDVKernel:
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16': if dtype == 'fp16' or dtype == 'bf16':
return { return {
# '32' : [FmhaBwdDQDKDVTileSize( 64, 64, 32, 64, 32, 64, 64, 32, 32, 1, 2, 1, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, 1), # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
# "kr_ktr_vr"], # "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1), '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1),
"kr_ktr_vr"], "kr_ktr_vr"],
......
...@@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -660,7 +660,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}(); }();
// STAGE 3, P^T@OGrad^T Gemm1 // STAGE 3, P^T@OGrad^T Gemm1
pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer(); Policy::template PTFromGemm0CToGemm1A<Problem,
decltype(pt_reg_tensor),
decltype(pt_gemm)>(pt_reg_tensor, pt_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
auto qt_reg_tensor = load_tile(qt_lds_read_window); auto qt_reg_tensor = load_tile(qt_lds_read_window);
...@@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -732,7 +734,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst); const auto dst_gemm = cast_tile<GemmDataType>(dst);
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer(); Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(dst_gemm)>(dst_reg_tensor, dst_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
...@@ -908,8 +912,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -908,8 +912,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
} }
}(); }();
pt_reg_tensor.get_thread_buffer() = pt_gemm.get_thread_buffer(); Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(pt_gemm)>(
auto dot_reg_tensor = load_tile(dot_lds_read_window); pt_reg_tensor, pt_gemm);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>(); HotLoopScheduler::template GemmStagedScheduler<1>();
...@@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -965,7 +970,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// STAGE 6, SGrad^T@Q^T Gemm3 // STAGE 6, SGrad^T@Q^T Gemm3
const auto dst_gemm = cast_tile<GemmDataType>(dst); const auto dst_gemm = cast_tile<GemmDataType>(dst);
dst_reg_tensor.get_thread_buffer() = dst_gemm.get_thread_buffer(); Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(dst_gemm)>(dst_reg_tensor, dst_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, dst_gemm); store_tile(ds_lds_window, dst_gemm);
......
...@@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1508,6 +1508,116 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return ds_block_dstr; return ds_block_dstr;
} }
template <typename Problem, typename PTOutTensor, typename PTInTensor>
CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
                                                          const PTInTensor& pt_in)
{
    // Repack the P^T tile produced by Gemm0 (C-register layout) into the
    // register layout Gemm1 consumes as its A operand.
    //
    // When the Gemm1 warp tile is 16 rows in M, the source sub-tiles are
    // addressed as (kIter, mIter) while the destination expects
    // (mIter, kIter), so each per-warp sub-tile is moved through a small
    // staging tensor with the two iteration indices swapped. For any other
    // warp-tile size the two thread buffers share one layout and a direct
    // buffer copy is sufficient.
    if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
    {
        using WG = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                                          typename Problem::OGradDataType,
                                          typename Problem::AccDataType,
                                          Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
                                          Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
                                          Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
                                          true>;
        using ADstr = typename WG::AWarpDstr;
        using CDstr = typename WG::CWarpDstr;

        constexpr index_t num_m_warps = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
        constexpr index_t m_per_block = Problem::BlockFmhaShape::kN0;
        constexpr index_t k_per_block = Problem::BlockFmhaShape::kK1;
        constexpr index_t m_iters     = m_per_block / (num_m_warps * WG::kM);
        constexpr index_t k_iters     = k_per_block / WG::kK;

        // One-warp-sub-tile staging buffer in the Gemm0 C distribution.
        auto staging = make_static_distributed_tensor<typename Problem::GemmDataType>(CDstr{});

        constexpr auto a_ys_lengths = to_sequence(ADstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_ys_lengths = to_sequence(CDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto a_ys_zeros   = uniform_sequence_gen_t<ADstr::NDimY, 0>{};
        constexpr auto c_ys_zeros   = uniform_sequence_gen_t<CDstr::NDimY, 0>{};

        static_for<0, k_iters, 1>{}([&](auto ik) {
            static_for<0, m_iters, 1>{}([&](auto im) {
                // Source slice is indexed (k, m); destination slice (m, k).
                staging.get_thread_buffer() = pt_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<ik, im>{}, c_ys_zeros),
                    merge_sequences(sequence<1, 1>{}, c_ys_lengths));
                pt_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<im, ik>{}, a_ys_zeros),
                    merge_sequences(sequence<1, 1>{}, a_ys_lengths),
                    staging.get_thread_buffer());
            });
        });
    }
    else
    {
        // Layouts already agree: element-for-element buffer copy.
        pt_out.get_thread_buffer() = pt_in.get_thread_buffer();
    }
}
template <typename Problem, typename SGradTOutTensor, typename SGradTInTensor>
CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
                                                              const SGradTInTensor& dst_in)
{
    // Repack the SGrad^T tile produced by Gemm2 (C-register layout) into the
    // register layout Gemm3 consumes as its A operand.
    //
    // When the Gemm3 warp tile is 16 rows in M, the source sub-tiles are
    // addressed as (kIter, mIter) while the destination expects
    // (mIter, kIter), so each per-warp sub-tile is moved through a small
    // staging tensor with the two iteration indices swapped. For any other
    // warp-tile size the two thread buffers share one layout and a direct
    // buffer copy is sufficient.
    if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
    {
        using WG = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                                          typename Problem::QDataType,
                                          typename Problem::AccDataType,
                                          Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
                                          Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
                                          Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
                                          true>;
        using ADstr = typename WG::AWarpDstr;
        using CDstr = typename WG::CWarpDstr;

        constexpr index_t num_m_warps = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
        constexpr index_t m_per_block = Problem::BlockFmhaShape::kN0;
        constexpr index_t k_per_block = Problem::BlockFmhaShape::kK3;
        constexpr index_t m_iters     = m_per_block / (num_m_warps * WG::kM);
        constexpr index_t k_iters     = k_per_block / WG::kK;

        // One-warp-sub-tile staging buffer in the Gemm2 C distribution.
        auto staging = make_static_distributed_tensor<typename Problem::GemmDataType>(CDstr{});

        constexpr auto a_ys_lengths = to_sequence(ADstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_ys_lengths = to_sequence(CDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto a_ys_zeros   = uniform_sequence_gen_t<ADstr::NDimY, 0>{};
        constexpr auto c_ys_zeros   = uniform_sequence_gen_t<CDstr::NDimY, 0>{};

        static_for<0, k_iters, 1>{}([&](auto ik) {
            static_for<0, m_iters, 1>{}([&](auto im) {
                // Source slice is indexed (k, m); destination slice (m, k).
                staging.get_thread_buffer() = dst_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<ik, im>{}, c_ys_zeros),
                    merge_sequences(sequence<1, 1>{}, c_ys_lengths));
                dst_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<im, ik>{}, a_ys_zeros),
                    merge_sequences(sequence<1, 1>{}, a_ys_lengths),
                    staging.get_thread_buffer());
            });
        });
    }
    else
    {
        // Layouts already agree: element-for-element buffer copy.
        dst_out.get_thread_buffer() = dst_in.get_thread_buffer();
    }
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{ {
......
Markdown is supported
0% Loading — or attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment