"vscode:/vscode.git/clone" did not exist on "55862cfc33d841422341fbe65538544a08fa0bab"
Commit e9f29196 authored by root's avatar root
Browse files

code backup

parent cdb83933
......@@ -195,13 +195,12 @@ struct FusedMoeKernel
index_t stride_d;
index_t stride_o;
index_t stride_expert_gu;
index_t stride_expert_gu;
index_t stride_expert_d;
};
using Hargs = FusedMoeCommonHargs;
CK_TILE_HOST static constexpr ToKargs(const Hargs hargs) { return kargs; }
CK_TILE_HOST static constexpr auto ToKargs(const Hargs hargs) { return hargs; }
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
{
......
......@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "fused_moe_pipeline_nsplit2_policy.hpp"
#include "fused_moe_pipeline_problem.hpp"
#include "fused_moe_tile_shape.hpp"
#include "fused_moe_traits.hpp"
#include "fused_moe_weight_permute_enum.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck_tile {
......@@ -16,7 +19,7 @@ namespace ck_tile {
This pipeline split the gemm-n of B matrix for less register pressure
(assume B matrix is much larger than A)
*/
template <typename Problem_, typename Policy_ = FusedMoePipelineNSplit2Policy>
template <typename Problem_, typename Policy_ = ck_tile::FusedMoePipelineNSplit2Policy>
struct FusedMoePipelineNSplit2
{
using Problem = remove_cvref_t<Problem_>;
......@@ -28,6 +31,7 @@ struct FusedMoePipelineNSplit2
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using YDataType = remove_cvref_t<typename Problem::AccDataType>;
using ScaleDataType = remove_cvref_t<typename Problem::ScaleDataType>;
using FusedMoeTileShape = remove_cvref_t<typename Problem::FusedMoeTileShape>;
......@@ -70,29 +74,24 @@ struct FusedMoePipelineNSplit2
static constexpr index_t kWarpRepeatN_1 = FusedMoeTileShape::kWarpRepeatN_1;
static constexpr index_t kWarpRepeatK_1 = FusedMoeTileShape::kWarpRepeatK_1;
using MBlockType_0 = decltype(Policy::GetMatrixCoreSwizzledBlockTIle_0<Problem>());
static constexpr index_t kBlockNr_0 = MBlockType_0 {}
::at(number<0>{});
static constexpr index_t kBlockKr_0 = MBlockType_0 {}
::at(number<1>{});
static constexpr index_t kBlockWaveFlatten = MBlockType_0 {}
::at(number<2>{});
using MBlockType_0 = decltype(Policy::template GetMatrixCoreSwizzledBlockTIle_0<Problem>());
static constexpr index_t kBlockNr_0 = MBlockType_0::at(number<0>{});
static constexpr index_t kBlockKr_0 = MBlockType_0::at(number<1>{});
static constexpr index_t kBlockWaveFlatten = MBlockType_0::at(number<2>{});
static_assert(kBlockNr_0 % 2 == 0);
static constexpr index_t kBlockSubNr_0 = kBlockNr_0 / 2;
using MBlockType_1 = decltype(Policy::GetMatrixCoreSwizzledBlockTIle_1<Problem>());
static constexpr index_t kBlockNr_1 = MBlockType_1 {}
::at(number<0>{});
static constexpr index_t kBlockKr_1 = MBlockType_1 {}
::at(number<1>{});
using MBlockType_1 = decltype(Policy::template GetMatrixCoreSwizzledBlockTIle_1<Problem>());
static constexpr index_t kBlockNr_1 = MBlockType_1::at(number<0>{});
static constexpr index_t kBlockKr_1 = MBlockType_1::at(number<1>{});
static constexpr index_t kBlockSubKr_1 = kBlockKr_1 / 2;
static_assert(kBlockSubNr_0 == kBlockSubKr_1);
static constexpr index_t kAlignmentA = Policy::GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::GetAlignment_G<Problem>();
static constexpr index_t kAlignmentU = Policy::GetAlignment_U<Problem>();
static constexpr index_t kAlignmentD = Policy::GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::GetAlignment_O<Problem>();
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentU = Policy::template GetAlignment_U<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
......@@ -106,12 +105,12 @@ struct FusedMoePipelineNSplit2
static constexpr const char* name = "fused_moe_ns2";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSingleBuffer()
{
return Policy<Problem>::GetSmemSizeSingleBuffer();
return Policy::template GetSmemSizeSingleBuffer();
}
// this is the thread-offset along row/col
......@@ -141,7 +140,7 @@ struct FusedMoePipelineNSplit2
const DGlobalTileWindow& d_gtile_window_tmp,
OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr,
ScaleDataType scale,
//ScaleDataType scale,
CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size,
......@@ -231,8 +230,8 @@ struct FusedMoePipelineNSplit2
statically_indexed_array<g_thread_type, 2> g_tls;
statically_indexed_array<u_thread_type, 2> u_tls;
using WarpGemm0 = Policy::GetWarpGemm0<Problem>();
using WarpGemm1 = Policy::GetWarpGemm1<Problem>();
using WarpGemm0 = remove_cvref_t<decltype(Policy::template GetWarpGemm0<Problem>())>;
using WarpGemm1 = remove_cvref_t<decltype(Policy::template GetWarpGemm1<Problem>())>;
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
......@@ -270,8 +269,8 @@ struct FusedMoePipelineNSplit2
move_tile_window(d_win, {number<0>{}, number<kBlockKr_0>{}, number<0>{}});
};
auto acc_g = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
auto acc_u = generate_tuple([&](auto) { MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
auto acc_g = generate_tuple([&](auto) {Policy::template MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
auto acc_u = generate_tuple([&](auto) {Policy::template MakeCBlockTile_Gemm0<Problem>(); }, number<2>{});
// Note this function only do gemm of single Nsplit
// clang-format off
......@@ -408,8 +407,8 @@ struct FusedMoePipelineNSplit2
sweep_tile_span(acc_spans_0[number<0>{}], [&](auto idx0) {
sweep_tile_span(acc_spans_0[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
element_wise::Silu{}(acc_g[I0](i_j_idx), acc_g[I0](i_j_idx));
element_wise::Silu{}(acc_g[I1](i_j_idx), acc_g[I1](i_j_idx));
ck::tensor_operation::element_wise::Silu{}(acc_g[I0](i_j_idx), acc_g[I0](i_j_idx));
ck::tensor_operation::element_wise::Silu{}(acc_g[I1](i_j_idx), acc_g[I1](i_j_idx));
acc_g[I0](i_j_idx) *= acc_u[I0](i_j_idx);
acc_g[I1](i_j_idx) *= acc_u[I1](i_j_idx);
});
......@@ -419,7 +418,7 @@ struct FusedMoePipelineNSplit2
if constexpr(std::is_same_v<YDataType, fp16_t>) return impl::cast_tile_pk_fp16_fp32<YDataType>(acc_g[i]);
else return cast_tile<YDataType>(acc_g[i]); }, number<2>{});
auto acc_d = MakeCBlockTile_Gemm1<Problem>();
auto acc_d = Policy::template MakeCBlockTile_Gemm1<Problem>();
// TODO: reshuffle? 32x32x8 mfma can avlid LDS reshuffle
// Second gemm
......
......@@ -73,16 +73,16 @@ struct FusedMoeTileShape
static constexpr index_t kWarpM_0 = Gemm0WarpTile::at(number<0>{});
static constexpr index_t kWarpN_0 = Gemm0WarpTile::at(number<1>{});
static constexpr index_t kWarpK_0 = Gemm0WarpTile::at(number<2>{});
static constexpr index_t kBlockWarpsM_0 = Gemm0BlockWarps::at(numner<0>{});
static constexpr index_t kBlockWarpsN_0 = Gemm0BlockWarps::at(numner<1>{});
static constexpr index_t kBlockWarpsK_0 = Gemm0BlockWarps::at(numner<2>{});
static constexpr index_t kBlockWarpsM_0 = Gemm0BlockWarps::at(number<0>{});
static constexpr index_t kBlockWarpsN_0 = Gemm0BlockWarps::at(number<1>{});
static constexpr index_t kBlockWarpsK_0 = Gemm0BlockWarps::at(number<2>{});
static constexpr index_t kSubBlockM_0 = kWarpM_0 * kBlockWarpsM_0;
static constexpr index_t kSubBlockN_0 = kWarpN_0 * kBlockWarpsN_0;
static constexpr index_t kSubBlockK_0 = kWarpK_0 * kBlockWarpsK_0;
static_assert(kBlockM_0 % kSubBlockM_0 == 0);
static_assert(kBlockN_0 % kSubBlockN_0 == 0);
static_assert(kBlockK_0 % kSubBlockK_0 == 0);
static constexpr index_t kWarpRepeatM_0 = kBlockM_0 / kSubBlockM_0;
static constexpr index_t kWarpRepeatM_0 = kBlockM_0 / kSubBlockM_0;//warp repeat is block repeat
static constexpr index_t kWarpRepeatN_0 = kBlockN_0 / kSubBlockN_0;
static constexpr index_t kWarpRepeatK_0 = kBlockK_0 / kSubBlockK_0;
......@@ -93,9 +93,9 @@ struct FusedMoeTileShape
static constexpr index_t kWarpM_1 = Gemm1WarpTile::at(number<0>{});
static constexpr index_t kWarpN_1 = Gemm1WarpTile::at(number<1>{});
static constexpr index_t kWarpK_1 = Gemm1WarpTile::at(number<2>{});
static constexpr index_t kBlockWarpsM_1 = Gemm1BlockWarps::at(numner<0>{});
static constexpr index_t kBlockWarpsN_1 = Gemm1BlockWarps::at(numner<1>{});
static constexpr index_t kBlockWarpsK_1 = Gemm1BlockWarps::at(numner<2>{});
static constexpr index_t kBlockWarpsM_1 = Gemm1BlockWarps::at(number<0>{});
static constexpr index_t kBlockWarpsN_1 = Gemm1BlockWarps::at(number<1>{});
static constexpr index_t kBlockWarpsK_1 = Gemm1BlockWarps::at(number<2>{});
static constexpr index_t kSubBlockM_1 = kWarpM_1 * kBlockWarpsM_1;
static constexpr index_t kSubBlockN_1 = kWarpN_1 * kBlockWarpsN_1;
static constexpr index_t kSubBlockK_1 = kWarpK_1 * kBlockWarpsK_1;
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
#include "fused_moe_weight_permute_enum.hpp"
namespace ck_tile {
......
......@@ -2,9 +2,15 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe.hpp"
#include "ck_tile/host.hpp"
#include "rotary.hpp"
#include "utils.hpp"
//#include "rotary.hpp"
//#include "utils.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
#include "include/ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp"
#include <array>
#include <cstring>
......@@ -15,7 +21,7 @@
#include <tuple>
#include <utility>
#include <vector>
#include <torch/torch.h>
//#include <torch/torch.h>
//test args
auto create_args(int argc, char* argv[])
{
......@@ -106,14 +112,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
shard_intermediate_size/2,
hidden_size});
ck_tile::HostTensor<DDataType> d_host({num_experts,
hidden_size
hidden_size,
shard_intermediate_size/2});
ck_tile::reference_permute<GDataType>(g_host_ref, g_host, {0, 1, 3, 4, 2, 5})
ck_tile::reference_permute<GDataType>(u_host_ref, u_host, {0, 1, 3, 4, 2, 5})
ck_tile::reference_permute<GDataType>(d_host_ref, d_host, {0, 1, 3, 4, 2, 5})
ck_tile::reference_permute<GDataType>(g_host_ref, g_host, {0, 1, 3, 4, 2, 5});
ck_tile::reference_permute<GDataType>(u_host_ref, u_host, {0, 1, 3, 4, 2, 5});
ck_tile::reference_permute<GDataType>(d_host_ref, d_host, {0, 1, 3, 4, 2, 5});
ck_tile::HostTensor<ODataType> o_host({num_tokens, hidden_size});
ck_tile::HostTensor<FP32> sorted_weights({num_tokens,topk});
ck_tile::HostTensor<ck_tile::fp32_t> sorted_weights({num_tokens,topk});
ck_tile::HostTensor<ck_tile::index_t> sorted_topk_ids({num_tokens,topk});
ck_tile::HostTensor<ck_tile::index_t> sorted_expert_ids({num_tokens,topk});
ck_tile::HostTensor<ck_tile::index_t> sorted_num_tokens_post_padded({1});
......@@ -161,26 +167,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.stride_expert_gu = stride_expert_gu;
args.stride_expert_d = stride_expert_d;
args.dim_size = dim_size;
// args.dim_size = dim_size;
args.hidden_size = hidden_size;
args.num_tokens = num_tokens; // input number of tokens for current iteration
args.num_experts = num_experts;
}
};
//
constexpr ck_tile::index_t ts_experts = experts_;
// constexpr ck_tile::index_t ts_experts = experts_;
//tiling
using moe_block_tile_0 = ck::Sequence<32, // kM_a
using moe_block_tile_0 = ck_tile::sequence<32, // kM_a
128, // kN_g/u
128, // kN_sub0
32, // kK_a
128 // kN_d
>;
using moe_block_warps0_0 = ck::Sequence<1, 4, 1>;//mnk
using moe_block_warps1_0 = ck::Sequence<4, 1, 1>;
using moe_warp_tile_0 = ck::Sequence<32, 32, 16>;
using moe_block_warps0_0 = ck_tile::sequence<1, 4, 1>;//mnk
using moe_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using moe_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
// using fmha_warp_tile_4 = ck::Sequence<32, 32, 8>;
using moe_shape = ck::tile_program::FusedMoeTileShape<moe_block_tile_0,
using moe_shape = ck_tile::FusedMoeTileShape<moe_block_tile_0,
moe_block_warps0_0,
moe_warp_tile_0,
moe_block_warps1_0,
......@@ -188,10 +194,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
using moe_traits = ck_tile::FusedMoeTraits<false,//down preshuffle
-1, // index_t kBlockPerCu_ = ,overwrite occupancy if not -1
0,//index_t OAtomic_
FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv//FusedMoeWeightPermuteEnum WeightPermute_ =
ck_tile::FusedMoeWeightPermuteEnum::permute_b_nr_kr_kw_nw_kv//FusedMoeWeightPermuteEnum WeightPermute_ =
>;
using moe_problem = ck_tile::FusedMoePipelineProblem<ADataType, GDataType, UDataType, DDataType,
ODataType, AccDataType, ScaleDataType, GateActivation, moe_shape, moe_traits>;
ODataType, AccDataType, ScaleDataType, ck::tensor_operation::element_wise::Silu, moe_shape, moe_traits>;
using moe_pipeline = ck_tile::FusedMoePipelineNSplit2<moe_problem>;
using Hargs = ck_tile::FusedMoeKernel::FusedMoeCommonHargs;
using moe_partitioner = ck_tile::FusedMoeTilePartitioner_PersistentSplitD<moe_shape>; \
......@@ -240,4 +246,4 @@ int main(int argc, char* argv[])
//call run
//return
\ No newline at end of file
//return
......@@ -6,6 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_nsplit2.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_tile_shape.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
#include "include/ck_tile/ops/fused_moe/pipeline/fused_moe_weight_permute_enum.hpp"
#include <type_traits>
......
# fused multi-head attention
This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_fmha_fwd -j
```
This will result in an executable `build/bin/tile_example_fmha_fwd`
## kernel
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
There are 3 template parameters for this kernel template.
* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose.
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
* `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support.
## codegen
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
## executable
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all supported args. Below is an example of the output (may subject to change)
```
args:
-v weather do CPU validation or not (default:1)
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
-h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
-s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
't:l,r', top-left sliding window attn(swa) with FA style left right size
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
## support features
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
### hdim
Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default)
### group/batch mode
Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below)
### MQA/GQA
By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers.
### input/output permute, and `b*s*3*h*d`
If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`.
### attention bias
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
### alibi
alibi is supported
### lse
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
### vlayout
We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts.
### attention mask
we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right.
Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out.
![](misc/gamc.png)
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
| mask case| cmdline | FA style | xformer style |
|----------|:-------------:|:-------------:|:-------------:|
| no mask | `-mask=0`(default) | | |
| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right.
### dropout
TBD
## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+.
Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment