Commit 76cf795a authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'develop' into codegen_target

parents f42a8811 f16ebf82
...@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::QDataType, typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>; Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType, typename Problem::QDataType,
...@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim, Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>; Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>; Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType, typename Problem::OGradDataType,
...@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>; Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>; Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit if no work to do
if constexpr(FmhaMask::IsMasking) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
{ {
......
...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit // check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment