"vscode:/vscode.git/clone" did not exist on "f65c4bc3c0dd132e521dba5be8bdcecfefd44b4d"
Unverified Commit e8927110 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #1419 from ROCm/ck_tile/fa_bwd_opt_clean

Remove duplicated codes for creating WarpGemm
parents ed8ef7e5 5a561b5e
...@@ -815,15 +815,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -815,15 +815,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
typename Problem::QDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
...@@ -853,15 +847,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -853,15 +847,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
typename Problem::QDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
...@@ -902,15 +890,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -902,15 +890,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
typename Problem::OGradDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::VDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
...@@ -940,15 +922,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -940,15 +922,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
typename Problem::OGradDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::VDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
...@@ -1029,14 +1005,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1029,14 +1005,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
...@@ -1077,15 +1048,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1077,15 +1048,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
typename Problem::QDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
...@@ -1167,14 +1132,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1167,14 +1132,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::QDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
...@@ -1204,14 +1164,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1204,14 +1164,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::QDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
...@@ -1300,15 +1255,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1300,15 +1255,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor()
{ {
using WarpGemm = WarpGemmMfmaDispatcher< using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
typename Problem::OGradDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::VDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
...@@ -1389,14 +1338,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1389,14 +1338,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::OGradDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
...@@ -1427,14 +1371,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1427,14 +1371,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::OGradDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
...@@ -1474,14 +1413,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1474,14 +1413,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor()
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
...@@ -1514,14 +1448,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1514,14 +1448,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16) if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::OGradDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
...@@ -1569,14 +1498,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1569,14 +1498,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16) if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
{ {
using WarpGemm = using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::QDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment