"library/vscode:/vscode.git/clone" did not exist on "3da5c19e629174c234fe86c17ebd04732ea548b7"
Unverified Commit 851c3ed1 authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

[CK_TILE] support alibi (#1269)



* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------
Co-authored-by: default avatarrocking <ChunYu.Lai@amd.com>
parent 6d073d31
......@@ -17,7 +17,7 @@ add_custom_command(
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding tile_example ${EXAMPLE_NAME}")
message("adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
......
......@@ -30,27 +30,29 @@ args:
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
-h_k num of head, for k/v, 0 means equal to h (default:0)
-h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
-s_k seqlen_k, 0 means equal to s (default:0)
-s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128)
-d_v head dim for v, 0 means equal to d (default:0)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k
-range_q per-tensor quantization range of q. used if squant=1. (default:2)
-range_k per-tensor quantization range of k. used if squant=1. (default:2)
-range_v per-tensor quantization range of v. used if squant=1. (default:2)
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2)
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias add bias or not (default:0)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
......@@ -59,11 +61,11 @@ args:
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
......@@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov
### attention bias
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
### alibi
alibi is supported
### lse
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};
......@@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[])
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"0",
"num of head, for k/v, 0 means equal to h\n"
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s",
"0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
......@@ -71,7 +71,11 @@ auto create_args(int argc, char* argv[])
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
......@@ -153,7 +157,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k == 0)
if(nhead_k < 0)
nhead_k = nhead;
if(nhead % nhead_k != 0)
......@@ -164,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
if(seqlen_k == 0)
if(seqlen_k < 0)
seqlen_k = seqlen_q;
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v == 0)
if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
......@@ -208,9 +212,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::string vlayout = arg_parser.get_str("vlayout");
bool use_bias = arg_parser.get_bool("bias");
bool lse = arg_parser.get_bool("lse");
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
int init_method = arg_parser.get_int("init");
......@@ -295,12 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
ck_tile::HostTensor<BiasDataType> bias_host(
use_bias
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
......@@ -341,6 +351,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
assert(slopes.size() == nhead);
if(bias.rank_info == 0)
{
// alibi in 1*h
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
}
else
{
// alibi in b*h
for(auto i_b = 0; i_b < batch; i_b++)
{
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
}
}
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
......@@ -350,6 +378,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
......@@ -357,6 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off
auto layout_str = [&](bool permute){
......@@ -372,9 +402,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s
<< ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
<< std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
......@@ -382,7 +412,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
use_bias,
bias.type,
lse,
squant};
......@@ -441,7 +471,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
......@@ -461,7 +492,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_q,
stride_k,
stride_v,
stride_bias,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
......@@ -564,8 +596,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
if(use_bias)
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
......@@ -582,6 +615,51 @@ bool run(const ck_tile::ArgParser& arg_parser)
SMPLComputeDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
else if(bias.type == bias_enum::alibi)
{
// alibi construct elementwise bias to verify
auto alibi_host = [&]() {
if(mask.type != mask_enum::no_mask)
{
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
0,
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
}
else
{
return ck_tile::Alibi<SaccDataType, true>{
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
}
}();
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
for(auto i_h = 0; i_h < nhead; i_h++)
{
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
alibi_host.slope = current_slope;
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
SaccDataType pixel = 0;
alibi_host.update(pixel, i_r, i_c);
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
}
}
}
// [nhead, real_seqlen_q, real_seqlen_k]
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
SaccDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
s_host_ref, alibi_bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
......
......@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
template <typename DataType>
......@@ -86,7 +87,7 @@ struct fmha_fwd_args
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
......@@ -106,7 +107,7 @@ struct fmha_fwd_args
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
......@@ -219,7 +220,7 @@ template <ck_tile::index_t HDim_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
......@@ -240,7 +241,7 @@ struct fmha_fwd_traits_
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
......@@ -261,7 +262,7 @@ struct fmha_fwd_traits
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bool has_bias;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
......
......@@ -40,6 +40,19 @@ MASK_MAP = {
"generic" : "FmhaMasks::GenericMask"
}
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
......@@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = {
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
......@@ -213,7 +226,7 @@ class FmhaFwdApiTrait:
bk0blen : int
vlayout : str
mask : str
bias : str # true/false
bias : str #
lse : str #
squant : str #
spad : str
......@@ -241,8 +254,8 @@ class FmhaFwdApiTrait:
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
......@@ -297,7 +310,7 @@ class FmhaFwdPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias == 't' : n += '_bias'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
......@@ -332,7 +345,8 @@ class FmhaFwdApiPool:
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
......@@ -400,7 +414,7 @@ class FmhaFwdKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
......@@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1)
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
}
else:
return None
......@@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
......@@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
else:
assert False
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -149,11 +149,9 @@ struct mask_info
return tmp;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi);
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
};
inline std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
......@@ -17,7 +17,7 @@ for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in 0 1 ; do
for bias in "n" "e" "a"; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
......@@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done
......@@ -37,9 +38,11 @@ done
done
for perm in 0 1 ; do
for bias in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done
done
done
done
......@@ -154,3 +154,8 @@
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
// TODO: better solve this inside compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif
......@@ -536,4 +536,15 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
// TODO: this is hacky, we use u16
return __builtin_amdgcn_sad_u16(x, y, acc);
}
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
return (x > y ? (x - y) : (y - x)) + acc;
}
} // namespace ck_tile
......@@ -3,7 +3,9 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
NO_BIAS = 0,
ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale)
ALIBI = 2, // bias computed with position encoding, applied after scale
};
template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
static constexpr const char* name = "";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
static constexpr const char* name = "bias";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
static constexpr const char* name = "alibi";
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include <cmath>
#include <vector>
namespace ck_tile {
enum struct PositionEncodingEnum
{
NO = 0,
ALIBI = 1,
};
/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT:
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
5 4 3 2 1 [0]
*/
enum struct AlibiMode
{
VERTICAL = 0,
FROM_TOP_LEFT = 1, // keep sync with mask enum
FROM_BOTTOM_RIGHT = 2,
};
template <typename DataType, bool RowMajor = true>
struct Alibi
{
// RowMajor here means if pixel within the same thread are along the row, or col
// this may impact the performance of update(), while the result are the same.
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
CK_TILE_HOST_DEVICE Alibi(DataType slope_,
index_t y_total_,
index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL)
{
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
shift_left_up = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
}();
shift_right_down = [&]() {
if(RowMajor)
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
}
else
{
return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
}
}();
mode = mode_;
}
CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
{
if constexpr(RowMajor)
{
// at least 3 instructions per row
index_t current_zero_point =
mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
// for every threads, most of the pixels are along the row, below operation should be
// the main hot spot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(col_idx + shift_left_up),
0));
pixel += slope * position;
}
else
{
// at least 3 instructions per col;
index_t current_zero_point = mode == AlibiMode::VERTICAL
? row_idx + col_idx + shift_right_down
: col_idx + shift_right_down;
// for every threads, most of the pixels are along the col, below operation should be
// the main hot spot.
auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
bit_cast<uint32_t>(row_idx + shift_left_up),
0));
pixel += slope * position;
}
}
DataType slope; // float?
index_t shift_left_up; // always possitive
index_t shift_right_down; // always possitive
AlibiMode mode;
};
template <typename DataType>
struct EmptyPositionEncoding
{
CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
{
}
};
//
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
index_t window_left_size,
index_t window_right_size,
index_t y_total,
index_t x_total,
GenericAttentionMaskEnum mask_enum)
{
// assume mask_enum will never be NO_MASK, since if we do not have mask, it's
// totally OK to use constexpr
bool is_causal = window_left_size < 0 && window_right_size == 0;
AlibiMode alibi_mode =
is_causal ? AlibiMode::VERTICAL
: static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
float start = std::powf(
static_cast<float>(2),
-std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
std::vector<DataType> rtn;
for(auto i = 0; i < n; i++)
{
rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
}
return rtn;
};
if(is_power_of_two_integer(nheads))
{
// power of 2 calculation
return get_slopes_power_of_2(nheads);
}
else
{
ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
auto v0 = get_slopes_power_of_2(closest_power_of_2);
auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
std::vector<DataType> sliced;
for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
{
if(i % 2 == 0)
sliced.push_back(vec[i]);
}
std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
return sliced_2;
}(v1, nheads - closest_power_of_2);
v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
return v0;
}
}
} // namespace ck_tile
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
......@@ -33,6 +34,7 @@ struct FmhaFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
......@@ -41,7 +43,7 @@ struct FmhaFwdKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
......@@ -81,7 +83,8 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
......@@ -136,6 +139,13 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias = 0;
};
struct FmhaFwdAlibiKargs
{
// alibi is batch*nhead*1, no matter in batch/group mode, they are the same
const void* alibi_slope_ptr;
ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
};
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
......@@ -162,7 +172,11 @@ struct FmhaFwdKernel
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdBatchModeBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
......@@ -175,7 +189,11 @@ struct FmhaFwdKernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdCommonBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
......@@ -255,13 +273,18 @@ struct FmhaFwdKernel
batch_stride_v,
batch_stride_o};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
......@@ -345,12 +368,17 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
......@@ -421,14 +449,10 @@ struct FmhaFwdKernel
{
batch_offset_v = key_start;
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
else
{
batch_offset_bias = key_start;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
......@@ -461,7 +485,7 @@ struct FmhaFwdKernel
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
......@@ -585,7 +609,7 @@ struct FmhaFwdKernel
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const BiasDataType* bias_ptr =
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
......@@ -654,6 +678,39 @@ struct FmhaFwdKernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
// WA i_batch capture structure binding before c++20
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// data loading, shared by entire wg
// TODO: how to use s_read?
SaccDataType slope =
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask)
{
return make_alibi_from_lr_mask<SaccDataType, true>(slope,
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type);
}
else
{
return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
}
}
else
{
return EmptyPositionEncoding<SaccDataType>{};
}
}();
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
......@@ -672,6 +729,7 @@ struct FmhaFwdKernel
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}
......@@ -683,6 +741,7 @@ struct FmhaFwdKernel
bias_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}
......
......@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
};
template <BlockFmhaPipelineEnum>
struct BlockFmhaPipelineEnumToStr;
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
{
static constexpr const char* name = "qr";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
{
static constexpr const char* name = "qr_async";
};
template <>
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
{
static constexpr const char* name = "qs";
};
} // namespace ck_tile
......@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasBias = Traits::kHasBias;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
......@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
......@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
......@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
{
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && kHasBias)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && kHasBias)
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
// check early exit
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
{
if(num_total_loop <= 0)
{
......@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
/// consideration. alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
......@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
......@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
......@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
FmhaMask mask,
PositionEncoding /*position_encoding*/,
float scale_s,
float descale_qk,
float descale_sv,
......@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
......@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
tile_elementwise_inout(
[&](auto& x, const auto& y) {
......@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace ck_tile {
......@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kHasBias = Problem::kHasBias;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr index_t kBlockPerCu = []() {
......@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
......@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
......@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
k_block_tile = load_tile(k_dram_window);
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
......@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
......@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
......@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
......@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
......@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
......@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
......@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename LSEDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
......@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr) const
{
......@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
identity{},
identity{},
mask,
position_encoding,
scale_s,
smem_ptr);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment