"vscode:/vscode.git/clone" did not exist on "8a2c69eeee29ffc4493654578c27214ecdddb64d"
Unverified Commit 79a5d9c1 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

[CK_TILE] FA bwd kernels optimization (#1397)



* tmp save

* fix batch deterministic bugs

* fix group deterministic bugs

* codegen update

* reorder files

* bias support

* hd256 bias support

* bwd smoke test update

* simplify convert dq

* fix hd256 dropout scratch

* do{}while() -> while(){}

* comments

* remove FmhaBwdTilePartitioner

* save clear_tile

* refactor dropout

* code cleanup

* code cleanup

* comments

* fix epilogue problem

* fix fwd dropout

* group convert_dq opt

* fix dq alignment

* Do not store storerandval in bwd for flash attention integration

* fix hd32 error and boost performance

* revert

* Remove duplicated WarpGemm definitions in the policy file

* dropout patch for mrepeat 16*16

* code sync up

* dq_acc stride

* dq_acc stride stuff

* codegen update

* fwd dropout revert

* fix hd128 scratches and boost performance

* receipt 3 for simplified smoke test

* more strides for fa integration

* fix hd64 scratches and boost performance

* non-iglp pipeline for headdim padding cases

* dpad same as dvpad for flash attention integration

* unpadded lse&d for group mode

* Support unpad layout for group lse

* Support unpad lse layout for splitkv

* Fix stride for splitkv kernel

* fix unpadded lse issue in fwd splitkv

* comment

* solve lds read&write conflicts

* rename

* bias rename

* tile index revert

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarQianfeng Zhang <Qianfeng.Zhang@amd.com>
parent 2581727d
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO struct BlockFmhaBwdOGradDotO
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO ...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static constexpr index_t kAlignmentO = static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// These templates are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ false,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ false,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k & k^t located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, k located in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// This pipeline is v located in regs, q & k & do located in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -8,9 +8,8 @@ namespace ck_tile { ...@@ -8,9 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching // This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum enum class BlockFmhaBwdPipelineEnum
{ {
KSKTSVR = 0, KRKTRVR_IGLP = 0,
QSKSVROGradS, KRKTRVR,
KSVR,
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -24,7 +24,9 @@ template <typename QDataType_, ...@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_, typename BiasGradDataType_,
typename BlockFmhaShape_, typename BlockFmhaShape_,
bool kIsGroupMode_, bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
typename Traits_> typename Traits_>
struct BlockFmhaBwdPipelineProblem struct BlockFmhaBwdPipelineProblem
{ {
...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem ...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>; using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem ...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem ...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename AccDataType_,
typename QGradDataType_,
index_t kBlockSize_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
bool kIsGroupMode_,
bool kIsDeterministic_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Traits = remove_cvref_t<Traits_>;
static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits ...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......
...@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK, KPerBlock == BlockGemmShape::kK,
// "wrong!"); "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment