Unverified Commit 14c3cfb1 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

[CK_TILE] Improve headdim96 performance for fmha-bwd (#1573)



* Add kQKHeaddimForGemmN and kVHeaddimForGemmN in order to support headdim 96

* Remove the using of MakeKRegBlockDescriptor and MakeVRegBlockDescriptor

* Fix in bwd_piple_default_policy

* Remove kQKHeaddim and rename kQKHeaddimForGemmN to kQKHeaddim in the bwd kernel and pipelines

* Replace kVHeaddimForGemmN by kVHeaddim and kDoDvHeaddim

* Update to hd96 tile settings

* Add smoke test scripts for fmha-bwd hd96

* Revert "Add smoke test scripts for fmha-bwd hd96"

This reverts commit 7ca7e1a93dc65eb99ce3ff4e82693589830e42a2.

* Remove hd96 tile settings in fmha_bwd codegen to save compiling

* Fix lost code line in bwd_pipeline_default_policy

* Merge kDoDvHeaddim/kPadHeadDimDoDv to kVHeaddim/kPadHeadDimV and remove TileFmhaBwdTraits

* Rename KRegSliceBlockDescriptor/VRegSliceBlockDescriptor to KRegBlockDescriptor/VRegBlockDescriptor

* tiny adjustments

---------
Co-authored-by: default avatarPo Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatardanyao12 <Dan.Yao@amd.com>
parent 10158b0f
...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window = auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window = auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}), make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(), k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>( auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
...@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>()); v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window = auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0}); make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window = auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}), make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(), v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>()); Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------ //------------------------------------------------------------------
...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window( auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>( auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds(); block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window); auto v_reg_tensor = load_tile(v_lds_read_window);
block_sync_lds(); block_sync_lds();
//---------------------------- Loop Load in ----------------------------// //---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS // Q: HBM ->Reg ->LDS
...@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>()); q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window = auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window = auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(), make_tile_window(q_lds_window.get_bottom_tensor_view(),
...@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window( auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>( auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
...@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>()); do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window = auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window = auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(), make_tile_window(do_lds_window.get_bottom_tensor_view(),
...@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window( auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>( auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
...@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t i_total_loops = 0; index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start; index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1"); static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3"); static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
......
...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>()); k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window = auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto k_lds_read_window = auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}), make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(), k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>( auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>()); Policy::template MakeKRegBlockDescriptor<Problem>());
...@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -204,15 +204,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>()); v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window = auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0}); make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
auto v_lds_read_window = auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}), make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(), v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>()); Policy::template MakeVRegBlockDescriptor<Problem>());
//------------------------------------------------------------------ //------------------------------------------------------------------
...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window( auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0}); shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>( auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>()); kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds(); block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window); auto v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------// //---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS // Q: HBM ->Reg ->LDS
auto q_dram_window = auto q_dram_window =
...@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>()); q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window = auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto q_lds_read_window = auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(), make_tile_window(q_lds_window.get_bottom_tensor_view(),
...@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window( auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0}); shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>( auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>()); qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
...@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>()); do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window = auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto do_lds_read_window = auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(), make_tile_window(do_lds_window.get_bottom_tensor_view(),
...@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window( auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0}); shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>( auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>()); dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
...@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ...@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t i_total_loops = 0; index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start; index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1"); static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3"); static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4; constexpr index_t k4_loops = kN0 / kK4;
......
...@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
...@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
...@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
...@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
...@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
...@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>(); return total_pixels / GetAlignmentK<Problem>();
...@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
...@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentV<Problem>(); constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>(); constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>(); constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{ {
...@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
...@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kVPack = GetSmemKPackV<Problem>(); constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{ {
...@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
...@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
...@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>(); constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
// Hold full block data // Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>(); constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
...@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM = static constexpr index_t WarpGemmM =
...@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute // Compute
static constexpr index_t Gemm0MFMA = static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim / kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA = static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 / kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA = static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 / kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
...@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>(); kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 = static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ = static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 = static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ = static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>(); kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write // LDS Write
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment