"include/vscode:/vscode.git/clone" did not exist on "fbc576b54403d63efc2c922664ebbb5a21887b02"
Commit 2a4c2316 authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_asm_bwd

parents 1e01ee09 770d2b77
...@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -25,14 +25,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::QDataType, typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kK0>,
Problem::BlockFmhaShape::kK0>>; typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType, typename Problem::QDataType,
...@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -57,14 +58,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
TileGemmShape<Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kVHeaddim, Problem::BlockFmhaShape::kK1>,
Problem::BlockFmhaShape::kK1>>; typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -88,14 +90,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kK2>,
Problem::BlockFmhaShape::kK2>>; typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType, typename Problem::OGradDataType,
...@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -120,14 +123,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
TileGemmShape<Problem::BlockFmhaShape::kN0, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kK3>,
Problem::BlockFmhaShape::kK3>>; typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -151,14 +155,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using BlockGemmProblem = using BlockGemmProblem = BlockGemmPipelineProblem<
BlockGemmPipelineProblem<typename Problem::GemmDataType, typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
Problem::kBlockSize, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
TileGemmShape<Problem::BlockFmhaShape::kM0, Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kQKHeaddim, Problem::BlockFmhaShape::kK4>,
Problem::BlockFmhaShape::kK4>>; typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
......
...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS ...@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit if no work to do
if constexpr(FmhaMask::IsMasking) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
{ {
......
...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit // check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK) if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{ {
if(num_total_loop <= 0) if(num_total_loop <= 0)
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
...@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -49,6 +49,10 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{ {
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
} }
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
}
} }
}; };
......
This diff is collapsed.
This diff is collapsed.
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment