Unverified Commit 0cb2e06d authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] fmha forward split-kv + combine kernels (#1338)



* FA fwd dropout

* FA bwd

* epilogue reuse

* CMakeLists update

* [CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>

* now fwd/bwd can build

* bwd alibi

* add bwd validation stream_config

* update generated filenames

* update bwd kernel launch

* CK_TILE_HOST_DEVICE in philox

* Transpose -> transpose

* format

* format

* format

* Generate the instance for FA required

* format

* fix error in WarpGemm

* Add num_splits option and dummy split-kv api method

* Generate fmha_fwd_splitkv()

* Add SplitKV kernel codegen logics

* Add SplitKV combine kernel codegen logics

* Fix mismatched return type

* Clean-up code

* Replace sentinel value before storing

* Fix wrong layout of LSE/LSEacc/Oacc

* Format codes

* Fix o_acc memory error

* Fix wrong kBlockSize used in policy

* Reduce # of combine kernels

* Fix split-kv combine kernel name

* Fix wrong LDS indexing logics

* Fix wrong loop counter step logic

* Undo vector size changes

* Remove no-longer used field

* Remove in-consistent comment

* Remove debug statements in example

* Remove more debug statements

* Add constness to local variables

* Clearn up generate.py

* Fix unstable clang-format comment

* Remove unused include directive

* Use shorter template parameter name

* Enable non-split-kv blobs

* Update license date

* Print num_splits conditionally

* Undo disabling data types

* Remove unnessary tile size for fp8

* Fix wrong pipeline args for fp8

* Fix example output format

* Remove more debug code in combine pipeline

* Add stride kernel arguments for LSE/O acc workspace

* Re-order split-kv pipeline call operator arguments

* Pass LSE/O strides in kernel argument

* Re-order pipeline call operator arguments

* Use tensor_descriptor to locate LSEacc elements

* Support providing invalid element for tensor view

* Set invalid element value for LSEacc tensor view

* Remove hand-written store_tile() code

* Remove necessary value-overwrite logic

* Add transposed lds descriptor

* Support load_tile() for tile_window_with_static_lengths<>

* Undo removing necessary value-overwrite logic

* Use read descriptor to locate lds elements

* Simplify pipeline source code

* Add constraint to kMaxSplits

* Default use kMaxSplits=64 in generate.py

* Revert "Add constraint to kMaxSplits"

This reverts commit 0a2132d758042e6fb0292f4e354909b8a4d1c118.

* Revert "Default use kMaxSplits=64 in generate.py"

This reverts commit c7d9c80b77320aec6559222bed7d47adcaefe4e3.

* Decide alignment by the padding parameter

* Remove no-longer used utility functions

* Remove not-working code

* Add comment & remove no-longer used code

* Fix computation errors

* Add heuristic to override num_splits option

* Add constraint to kMaxSplits

* Fix compilation error

* Clean up pipeline code

* Wrap pointer access as lambda function

* Rename confusing methods

* Use kLogMasSplits as template parameter

* Finish splitkv combine kernel codegen

* Update kMaxSplits limit

* Use smaller kM0 for splitkv combine kernel

* Ignore droupout flag in splitkv pipeline

* Unify flag usage

* Add back flag kStoreLSE

* Merge lambda calls in pipeline

* Fix compilation errors

* Avoid all empty splits

* Always check for empty loop in splitkv pipelines

* Re-order parameters

* Remove redundant p_drop option check

* Add traits/problem for fwd splitkv kernel

* Conditionally enable uneven split boundary checks

* Add comment for the splitkv traits field

* Change even split criteria

* Re-order statements

* Refine occupancy value for hdim=128&256

* Refine occupancy value for hdim=32&64

* Remove redundant kernel argument

* Separate fmha bwd codegen logics

* Separate fmha fwd codegen logics

* Remove redundant direction parameter in fwd&bwd codegen logics

* Support generate multiple APIs for an example

* Let 'api' an alias of 'direction' option

* Remove choices for the 'direction' option

* Use dictionary to config all the functions

* Move fmha splitkv codegen logics to other file

* Add fwd_splitkv api for tile_example_fmha_fwd

---------

Co-authored-by: danyao12 <danyao12>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
Co-authored-by: default avatarJing Zhang <jizhan@amd.com>
parent 3e9711f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ true,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// This pipeline is qkv all located in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
} // namespace ck_tile
...@@ -54,4 +54,69 @@ struct BlockFmhaPipelineProblem ...@@ -54,4 +54,69 @@ struct BlockFmhaPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename BiasDataType,
typename RandValOutputDataType,
typename LSEDataType,
typename PDataType,
typename OaccDataType,
typename ODataType,
typename BlockFmhaShape,
bool kIsGroupMode,
typename FmhaMask,
typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
BiasDataType,
RandValOutputDataType,
LSEDataType,
PDataType,
OaccDataType,
ODataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
Traits>
{
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
};
template <typename LSEDataType_,
typename OaccDataType_,
typename ODataType_,
index_t HeadDimV_,
index_t kM0_,
index_t kN1_,
bool kIsGroupMode_,
typename Traits_>
struct BlockFmhaSplitKVCombinePipelineProblem
{
using LSEDataType = remove_cvref_t<LSEDataType_>;
using OaccDataType = remove_cvref_t<OaccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = 256;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr index_t kHeadDimV = HeadDimV_;
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN1 = kN1_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr index_t kMaxSplits = Traits::kMaxSplits;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -32,6 +32,50 @@ struct TileFmhaTraits ...@@ -32,6 +32,50 @@ struct TileFmhaTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
template <bool kPadSeqLenQ /* padding for seqlen_q */,
bool kPadSeqLenK /* padding for seqlen_k */,
bool kPadHeadDimQ /* paddding for hdim_q */,
bool kPadHeadDimV /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum,
bool kHasBiasGrad,
bool kStoreLSE,
bool kHasDropout,
bool kDoFp8StaticQuant,
bool kHasUnevenSplits_ = true,
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQ,
kPadHeadDimV,
BiasEnum,
kHasBiasGrad,
kStoreLSE,
kHasDropout,
kDoFp8StaticQuant,
kBlockPerCu>
{
// determine if some split (length) is not divisible by tile size
static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kStoreLSE_,
bool kDoFp8StaticQuant_,
index_t kLogMaxSplits_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdSplitKVCombineTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_);
static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */, template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */, bool kPadHeadDimV_ /* paddding for hdim_v */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */> index_t kBlockPerCu_ = 2 /* hint to occupancy */>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment