Unverified Commit 8e4b351a authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to...


[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel (#12591)
Signed-off-by: default avatarRandall Smith <Randall.Smith@amd.com>
parent 9869453c
# SPDX-License-Identifier: Apache-2.0
"""Tests for the triton_flash_attention kernel
Run `pytest tests/kernels/test_triton_flash_attention.py`.
"""
import pytest
import torch
from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS,
MetaData,
compute_alibi_tensor,
scale_fp8,
triton_attention_rocm)
from vllm.platforms import current_platform
class ReferenceAttention:
def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype,
input_metadata):
self.Z = Z
self.HQ = HQ
self.HK = HK
self.N_CTX_Q = N_CTX_Q
self.N_CTX_K = N_CTX_K
self.D_HEAD = D_HEAD
self.use_alibi = use_alibi
self.dtype = dtype
self.input_metadata = input_metadata
def fwd(self, q, k, v):
scores = torch.einsum('bhqd,bhkd->bhqk', q,
k).float() * self.input_metadata.sm_scale
if self.input_metadata.causal:
mask = torch.tril(torch.ones(self.N_CTX_Q,
self.N_CTX_K,
device="cuda"),
diagonal=self.N_CTX_K - self.N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")
if self.input_metadata.bias is not None:
scores += self.input_metadata.bias
if self.use_alibi:
scores += compute_alibi_tensor(self.input_metadata.alibi_slopes,
self.N_CTX_Q, self.N_CTX_K)
p = torch.softmax(scores, dim=-1)
if self.input_metadata.causal:
# If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going
# into softmax. This creates a row of NaNs as -inf - -inf == NaN.
# So we fix this by converting the NaNs to 0s, which is what they
# should be out of the softmax.
nan_mask = torch.isnan(p)
p[nan_mask == 1] = 0
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v)
# compare
if self.input_metadata.layout == 'bshd':
ref_out = ref_out.transpose(1, 2).clone()
return ref_out
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
self.dtype)
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
self.dtype)
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
self.dtype)
result = self.fwd(q, k, v)
if self.input_metadata.o_scale is not None:
result, _ = scale_fp8(result, self.input_metadata.o_scale)
return result
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
k_descale, v_descale = (self.input_metadata.k_descale,
self.input_metadata.v_descale)
k_dequantized = (k_quantized.to(torch.float32) *
k_descale.to(torch.float32)).to(self.dtype)
v_dequantized = (v_quantized.to(torch.float32) *
v_descale.to(torch.float32)).to(self.dtype)
return self.fwd(q, k_dequantized, v_dequantized)
def varlen_fwd(self, q, k, v, is_mqa=False):
ref_out = torch.empty_like(q)
if is_mqa:
# Make KV look like HQ/HK "groups" of HK. Later, we will reshape so
# the size aligns with Q.
k_ref = k.view(k.shape[0], k.shape[1], 1,
k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1)
v_ref = v.view(v.shape[0], v.shape[1], 1,
v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1)
else:
k_ref = k
v_ref = v
for i in range(0, self.input_metadata.num_contexts):
start_q, start_k = self.input_metadata.cu_seqlens_q[
i], self.input_metadata.cu_seqlens_k[i]
end_q, end_k = self.input_metadata.cu_seqlens_q[
i + 1], self.input_metadata.cu_seqlens_k[i + 1]
k_curr = k_ref[start_k:end_k]
v_curr = v_ref[start_k:end_k]
if is_mqa:
k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q],
k_curr).float()
p = torch.softmax(scores * self.input_metadata.sm_scale,
dim=-1).half()
ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr)
return ref_out
def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False):
q_descale = None
if not fp8_kv:
q, q_descale = scale_fp8(q)
k, k_descale = scale_fp8(k)
v, v_descale = scale_fp8(v)
# In real world use case, the p scale would be a parameter trained by the
# model.
p_scale = None
o_scale = torch.rand(1, device="cuda",
requires_grad=False) if use_o_scale else None
return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale
def input_helper(
Z,
HQ,
HK,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
layout=None,
use_alibi=None,
causal=None,
is_fp8=False,
fp8_kv=False,
use_o_scale=False,
use_bias=False,
):
assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
current_platform.seed_everything(0)
# Initialize q, k, v
if layout == 'bhsd':
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
elif layout == 'bshd':
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
if use_alibi:
# for n heads the set of slopes is the geometric sequence that starts
# 2^(-8/n)
alibi_slopes = torch.tensor(
[2**(-8 / HQ * i) for i in range(1, HQ + 1)],
dtype=torch.float32,
device="cuda").repeat(Z, 1)
else:
alibi_slopes = None
if use_bias:
bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K),
dtype=dtype,
device="cuda",
requires_grad=False)
else:
bias = None
q = torch.randn(q_tensor_shape,
dtype=dtype,
device="cuda",
requires_grad=False)
k = torch.randn(k_tensor_shape,
dtype=dtype,
device="cuda",
requires_grad=False)
v = torch.randn(k_tensor_shape,
dtype=dtype,
device="cuda",
requires_grad=False)
if is_fp8:
(q, k, v, q_descale, k_descale, v_descale, p_scale,
o_scale) = quantize_input(q,
k,
v,
use_o_scale=use_o_scale,
fp8_kv=fp8_kv)
else:
q_descale = k_descale = v_descale = p_scale = o_scale = None
input_metadata = MetaData(sm_scale=D_HEAD**-0.5,
max_seqlens_q=N_CTX_Q,
max_seqlens_k=N_CTX_K,
layout=layout,
alibi_slopes=alibi_slopes,
alibi_batch=Z,
alibi_nheads=HQ,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=o_scale,
bias=bias,
seqlen_q=N_CTX_Q,
seqlen_k=N_CTX_K)
return q, k, v, input_metadata
def varlen_input_helper(Z,
HQ,
HK,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
equal_seqlens=False):
current_platform.seed_everything(0)
# Random sequence lengths. Using N_CTX as kind of max of sum of individual
# seqs
if not equal_seqlens:
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1,
max_seqlens_q + 1, (Z, ),
dtype=torch.int32)
seqlens_k = torch.randint(1,
max_seqlens_k + 1, (Z, ),
dtype=torch.int32)
else:
seqlens_q = torch.full((Z, ), N_CTX_Q // Z)
seqlens_k = torch.full((Z, ), N_CTX_K // Z)
# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat([
torch.tensor([0], dtype=torch.int32),
seqlens_q.cumsum(dim=0, dtype=torch.int32)
])
cu_seqlens_k = torch.cat([
torch.tensor([0], dtype=torch.int32),
seqlens_k.cumsum(dim=0, dtype=torch.int32)
])
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
# Initialize q, k, v with variable lengths
total_q = cu_seqlens_q[-1].item()
total_k = cu_seqlens_k[-1].item()
q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype,
device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.randn((total_k, HK, D_HEAD), dtype=dtype,
device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.randn((total_k, HK, D_HEAD), dtype=dtype,
device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
return q, k, v, input_metadata
@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
(1, 48, 12, 1, 1, 64),
(4, 4, 4, 128, 128, 65),
(16, 48, 48, 1, 1, 128),
(64, 48, 24, 3, 3, 128),
(4, 4, 4, 113, 123, 1),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_alibi', [True, False])
@pytest.mark.parametrize('layout', ['bshd'])
def test_op_fwd(Z,
HQ,
HK,
N_CTX_Q,
N_CTX_K,
D_HEAD,
causal,
use_alibi,
layout,
dtype=torch.float16):
current_platform.seed_everything(0)
q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD,
dtype, layout, use_alibi, causal)
o = torch.empty_like(q)
# triton implementation
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
# Transpose here if layout is bshd so we have same reference code for all
# layouts
if layout == 'bshd':
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
# Replicate K and V if using MQA/GQA
if HQ != HK:
k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
k.shape[3]).expand(-1, -1, HQ // HK, -1,
-1).reshape(k.shape[0], -1, k.shape[2],
k.shape[3])
v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
v.shape[3]).expand(-1, -1, HQ // HK, -1,
-1).reshape(v.shape[0], -1, v.shape[2],
v.shape[3])
ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD,
use_alibi, dtype, input_metadata)
ref_out = ref_impl.fwd(q, k, v)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('layout', ['bhsd'])
@pytest.mark.parametrize('use_o_scale', [True, False])
@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0),
reason="Triton FP8 requires CUDA 9.0 or higher")
def test_op_fwd_fp8(Z,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
causal,
layout,
use_o_scale,
dtype=torch.float32):
current_platform.seed_everything(0)
# Disable grad to save memory it won't run into OOM on CI machine.
# q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD,
# dtype, layout)
q_quantized, k_quantized, v_quantized, input_metadata = input_helper(
Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
causal=causal,
layout=layout,
is_fp8=True,
use_o_scale=use_o_scale)
o = torch.empty_like(q_quantized) if use_o_scale else None
tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized,
o, input_metadata)
ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
dtype, input_metadata)
ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized)
# compare
torch.testing.assert_close(ref_out.to(torch.float32),
tri_out.to(torch.float32),
atol=7e-2,
rtol=2e-1)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
(4, 4, 113, 123, 1),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('layout', ['bhsd'])
def test_op_fwd_fp8_kv(Z,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
causal,
layout,
dtype=torch.float32):
current_platform.seed_everything(0)
q, k_quantized, v_quantized, input_metadata = input_helper(Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
causal=causal,
layout=layout,
is_fp8=True,
fp8_kv=True)
o = torch.empty_like(q)
tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o,
input_metadata)
ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
dtype, input_metadata)
ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized)
torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
(4, 48, 1, 1, 64),
(4, 48, 1, 1, 128),
(4, 48, 3, 3, 128),
(4, 4, 128, 128, 65),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_bias', [True])
@pytest.mark.parametrize('dtype', [torch.bfloat16])
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
current_platform.seed_everything(0)
q, k, v, input_metadata = input_helper(Z,
H,
H,
N_CTX_Q,
N_CTX_K,
D_HEAD,
dtype,
layout='bhsd',
causal=causal,
use_bias=use_bias)
o = torch.empty_like(q)
# triton implementation
tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata)
ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
dtype, input_metadata)
ref_out = ref_impl.fwd(q, k, v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64),
(4, 48, 512, 64),
(16, 48, 512, 64),
(64, 48, 128, 128)])
@pytest.mark.parametrize('causal', [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX,
D_HEAD, dtype)
tri_out = torch.empty_like(q)
triton_attention_rocm(q, k, v, tri_out, input_metadata)
ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype,
input_metadata)
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64),
(4, 48, 12, 256, 64),
(4, 48, 4, 512, 64),
(4, 64, 16, 128, 128)])
@pytest.mark.parametrize('causal', [False])
def test_op_varlen_mqa_fwd(Z,
HQ,
HK,
N_CTX,
D_HEAD,
causal,
dtype=torch.float16):
q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX,
D_HEAD, dtype)
tri_out = torch.empty_like(q)
triton_attention_rocm(q, k, v, tri_out, input_metadata)
ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False,
dtype, input_metadata)
ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True)
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao This is a Triton implementation of the Flash Attention v2 algorithm
(https://tridao.me/publications/flash2/flash2.pdf) See https://tridao.me/publications/flash2/flash2.pdf
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported: Credits:
AMD Triton kernels team
OpenAI kernel team
1) Fwd with causal masking Currently only the forward kernel is supported, and contains these features:
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims 1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths
3) Arbitrary head sizes
4) Multi and grouped query attention
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
""" """
from typing import Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
torch_dtype: tl.constexpr = torch.float16 from vllm import _custom_ops as ops
from vllm.platforms import current_platform
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)
FP8_MIN = triton.language.constexpr(default_float8_info.min)
# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
FP8_MAX = triton.language.constexpr(
default_float8_info.max if default_eight_bit_dtype_torch !=
torch.float8_e4m3fnuz else 224.0)
class MetaData:
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
eight_bit = False
layout = None
return_encoded_softmax = False
eight_bit_dtype_triton = default_eight_bit_dtype_triton
eight_bit_dtype_torch = default_eight_bit_dtype_torch
output_dtype = None
# Note about layouts:
#
# thd - [num_tokens, num_heads, head_size]
# bshd - [batch_size, seq_len, num_heads, head_size]
# bhsd - [batch_size, num_heads, seq_len, head_size]
#
# This is for each tensor, all tensors must have same layout.
# Q can have num_heads and seq_len differ from from K and V,
# however K and V must agree on this.
#
# Notes about varlen and bias:
# Only one or the other is implemented, meaning can't combine
# both varlen and bias right now.
#
# Note about quantization:
# Only 8-bit quantization supported (for now) and specifically fp8.
# Scales must be tensors.
# o_scale: This is 'output scaling', but comes from parameter called
# 'input_scale', this is applied to the output from the kernel.
# o_scale should be None if none of the other quantization parameters
# are used.
#
# NOTE: Object is in a tentatively good state after initialized, however,
# to verify, call check_args(q,k,v,o) where o is the output tensor.
def __init__(
self,
sm_scale=1.0,
layout=None, # layout can be 'bshd', 'bhsd', or 'thd'
output_dtype=None,
max_seqlens_q=0,
max_seqlens_k=0,
# varlen params
cu_seqlens_q=None, # only 'thd' layout supported for varlen
cu_seqlens_k=None,
# quant params
q_descale=None,
k_descale=None,
v_descale=None,
p_scale=None,
o_scale=None,
# bias params
bias=None, # varlen not implemented for bias
seqlen_q=None,
seqlen_k=None,
# alibi params
alibi_slopes=None,
alibi_batch=None,
alibi_nheads=None,
# causal
causal=None,
):
self.sm_scale = sm_scale
self.output_dtype = output_dtype
self.max_seqlens_q = max_seqlens_q
self.max_seqlens_k = max_seqlens_k
self.layout = layout
if cu_seqlens_q is not None or cu_seqlens_k is not None:
assert cu_seqlens_q is not None and cu_seqlens_k is not None
assert layout is None or layout not in [
'bshd', 'bhsd'
], "Varlen only implemented for thd layout"
self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
if any(x is not None for x in quant_params):
p_descale = 1.0 / p_scale if p_scale is not None else None
self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale)
if bias is not None:
self.need_bias(bias, seqlen_q, seqlen_k)
if alibi_slopes is not None:
self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
if causal is not None and causal:
self.need_causal()
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(
cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(),
self.max_seqlens_q)
self.max_seqlens_k = max(
cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(),
self.max_seqlens_k)
def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale):
self.eight_bit = True
self.q_descale = q_descale
self.k_descale = k_descale
self.v_descale = v_descale
self.p_scale = p_scale
self.p_descale = p_descale
self.o_scale = o_scale
self.use_p_scale = (p_scale is not None) and (
p_descale is not None) and (v_descale is not None)
self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
and (v_descale is not None))
self.eight_bit_dtype_torch = default_eight_bit_dtype_torch
def need_bias(self, bias, seqlen_q, seqlen_k):
assert bias is not None
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
q, k, self)
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
assert not self.return_encoded_softmax
else:
assert q.dim() == 4
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
if self.eight_bit:
if self.eight_bit_kv:
assert (v.dtype == k.dtype
and k.dtype == self.eight_bit_dtype_torch)
assert q.dtype != k.dtype
assert (self.v_descale is not None) and (self.k_descale
is not None)
else:
assert (q.dtype == k.dtype and q.dtype == v.dtype
and q.dtype == self.eight_bit_dtype_torch)
assert (self.q_descale
is not None) and (self.k_descale
is not None) and (self.v_descale
is not None)
if self.use_p_scale:
assert (self.p_scale is not None) and (self.p_descale
is not None)
else:
assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
@triton.jit @triton.jit
...@@ -38,40 +244,85 @@ def max_fn(x, y): ...@@ -38,40 +244,85 @@ def max_fn(x, y):
return tl.math.max(x, y) return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit @triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): def masked_load(ptrs, offset_first, offset_second, boundary_first,
ms = tl.arange(0, m) boundary_second):
ns = tl.arange(0, n) if offset_first is not None and offset_second is not None:
return philox_offset + ms[:, None] * stride + ns[None, :] mask = (offset_first[:, None] < boundary_first) & \
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit @triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): def compute_alibi_block(alibi_slope,
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, seqlen_q,
stride).to(tl.uint32) seqlen_k,
# TODO: use tl.randint for better performance offs_m,
return tl.rand(philox_seed, rng_offsets) offs_n,
transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to
# the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is
# masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that
# spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
# offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] =
# [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q -
offs_n[None, :])
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k):
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): q_idx = torch.arange(seqlen_q, dtype=torch.int32,
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
stride) k_idx = torch.arange(seqlen_k, dtype=torch.int32,
rng_keep = rng_output > dropout_p device="cuda").unsqueeze(0) # (1, N_CTX_K)
return rng_keep relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q -
k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit @triton.jit
def load_fn(block_ptr, first, second, pad): def quant_fp8(x, scale):
if first and second: x *= scale
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) x = tl.clamp(x, FP8_MIN, FP8_MAX)
elif first: return x
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit @triton.jit
...@@ -80,58 +331,68 @@ def _attn_fwd_inner( ...@@ -80,58 +331,68 @@ def _attn_fwd_inner(
l_i, l_i,
m_i, m_i,
q, q,
K_block_ptr, k_ptrs,
V_block_ptr, v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
actual_seqlen_k, actual_seqlen_k,
dropout_p, actual_seqlen_q,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_softmax_block_ptr, encoded_sm_ptrs,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
bias_ptr, alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr, OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr, SHOULD_MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_PADDED_HEAD: tl.constexpr,
PADDED_HEAD: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
QK_SCALE: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
# loop over k, v, and update accumulator # loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N): for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if # For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range. # we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn( k_offs_n = start_n + tl.arange(0,
K_block_ptr, BLOCK_N) if SHOULD_MASK_STEPS else None
PADDED_HEAD, k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0), k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL,
"zero", actual_seqlen_k)
) if SHOULD_PRE_LOAD_V:
if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed.
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need # We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n # to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block. # TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102 if SHOULD_MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to # If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size # mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not
# if not is_modulo_mn. last step might get wasted but that is okay. # is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case. # check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], boundary_m = tl.full([BLOCK_M],
...@@ -144,167 +405,276 @@ def _attn_fwd_inner( ...@@ -144,167 +405,276 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf")) qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ---- # -- compute qk ----
qk += tl.dot(q, k) if IS_EIGHT_BIT_GEMM:
if bias_ptr is not None: qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) *
bias = load_fn(bias_ptr, False, MASK_STEPS QK_SCALE)
and (n_extra_tokens != 0), "zero") else:
# While bias is added after multiplying qk with sm_scale, our if IS_EIGHT_BIT_KV:
# optimization to use 2^x instead of e^x results in an additional k = (k * k_descale).to(q.type.element_ty)
# scale factor of log2(e) which we must also multiply the bias with. qk += (tl.dot(q, k) * QK_SCALE)
qk += bias * 1.44269504089
if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None] qk = qk - m_ij[:, None]
p = tl.math.exp2(qk) p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout # CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT: if SHOULD_RETURN_ENCODED_SOFTMAX:
philox_offset = (batch_philox_offset + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator -- # -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij) alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
if not PRE_LOAD_V: if not SHOULD_PRE_LOAD_V:
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i # -- update m_i and l_i
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
# update m_i and l_i # update m_i and l_i
m_i = m_ij m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) if IS_EIGHT_BIT_GEMM:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if USE_P_SCALE:
if bias_ptr is not None: p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) acc += tl.dot(p, v)
if RETURN_ENCODED_SOFTMAX: else:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, # v is in eight_bit but p is not, we want the gemm in p's type
(0, BLOCK_N)) acc += tl.dot(p, v.to(p.type.element_ty))
else:
if IS_EIGHT_BIT_KV:
v = (v * v_descale).to(p.type.element_ty)
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i return acc, l_i, m_i
@triton.autotune( def get_cdna_autotune_configs():
configs=[ return [
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 128,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 128,
"BLOCK_N": 64, 'BLOCK_N': 64,
"waves_per_eu": 2, 'waves_per_eu': 1,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 128,
"BLOCK_N": 128, 'BLOCK_N': 32,
"waves_per_eu": 2, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=4),
), ], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 32,
"BLOCK_N": 128, 'BLOCK_N': 32,
"waves_per_eu": 2, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 32,
"waves_per_eu": 1, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 3, 'waves_per_eu': 4,
"PRE_LOAD_V": True, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 3, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 64, 'BLOCK_M': 16,
"BLOCK_N": 64, 'BLOCK_N': 16,
"waves_per_eu": 4, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 32, 'BLOCK_M': 16,
"BLOCK_N": 32, 'BLOCK_N': 16,
"waves_per_eu": 4, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
), # Fall-back config.
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config( triton.Config(
{ {
"BLOCK_M": 16, 'BLOCK_M': 16,
"BLOCK_N": 16, 'BLOCK_N': 16,
"waves_per_eu": 1, 'waves_per_eu': 1,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
), ], [
], 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_general_autotune_configs():
return [
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 128,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 64,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 32,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def has_cdna_target():
ROCM_CDNA_TARGETS = ["gfx940", "gfx941", "gfx942", "gfx90a", "gfx908"]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_CDNA_TARGETS
def is_rocm_cdna():
return current_platform.is_rocm() and has_cdna_target()
def get_autotune_configs():
if is_rocm_cdna():
return get_cdna_autotune_configs()
elif current_platform.is_rocm():
return get_rdna_autotune_configs()
else:
return get_general_autotune_configs()
autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
) )
@triton.jit @triton.jit
def attn_fwd( def attn_fwd(
...@@ -312,38 +682,53 @@ def attn_fwd( ...@@ -312,38 +682,53 @@ def attn_fwd(
K, K,
V, V,
bias, bias,
sm_scale, SM_SCALE: tl.constexpr,
L, L,
Out, Out,
stride_qz, stride_qz: tl.int64,
stride_qh, stride_qh: tl.int64,
stride_qm, stride_qm: tl.int64,
stride_qk, stride_qk: tl.int64,
stride_kz, stride_kz: tl.int64,
stride_kh, stride_kh: tl.int64,
stride_kn, stride_kn: tl.int64,
stride_kk, stride_kk: tl.int64,
stride_vz, stride_vz: tl.int64,
stride_vh, stride_vh: tl.int64,
stride_vk, stride_vk: tl.int64,
stride_vn, stride_vn: tl.int64,
stride_oz, stride_oz: tl.int64,
stride_oh, stride_oh: tl.int64,
stride_om, stride_om: tl.int64,
stride_on, stride_on: tl.int64,
stride_bz, stride_bz: tl.int64,
stride_bh, stride_bh: tl.int64,
stride_bm, stride_bm: tl.int64,
stride_bn, stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
dropout_p,
philox_seed, philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base, philox_offset_base,
encoded_softmax, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HQ: tl.constexpr,
HK: tl.constexpr, HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr, VARLEN: tl.constexpr,
...@@ -351,24 +736,39 @@ def attn_fwd( ...@@ -351,24 +736,39 @@ def attn_fwd(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1) if o_descale_ptr is not None:
off_z = tl.program_id(2) o_descale = tl.load(o_descale_ptr)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) start_m: tl.int64 = tl.program_id(0)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN: if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too # We have a one-size-fits-all grid in id(0). Some seqlens might be
# small for all start_m so for those we return early. # too small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q: if start_m * BLOCK_M > seqlen_q:
return continue_condition = False
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
...@@ -378,444 +778,598 @@ def attn_fwd( ...@@ -378,444 +778,598 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking. if continue_condition:
# This is because for seqlen_q > seqlen_k, M rows of the attn scores # Now we compute whether we need to exit early due to causal
# are completely masked, resulting in 0s written to the output, and # masking. This is because for seqlen_q > seqlen_k, M rows of the
# inf written to LSE. We don't need to do any GEMMs in this case. # attn scores are completely masked, resulting in 0s written to the
# This block of code determines what N is, and if this WG is operating # output, and inf written to LSE. We don't need to do any GEMMs in
# on those M rows. # this case. This block of code determines what N is, and if this
n_blocks = cdiv_fn(seqlen_k, BLOCK_N) # WG is operating on those M rows.
if IS_CAUSAL: n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
# If seqlen_q == seqlen_k, the attn scores are a square matrix. if (IS_CAUSAL):
# If seqlen_q != seqlen_k, attn scores are rectangular which means # If seqlen_q == seqlen_k, the attn scores are a square matrix.
# the causal mask boundary is bottom right aligned, and ends at either # If seqlen_q != seqlen_k, attn scores are rectangular which
# the top edge (seqlen_q < seqlen_k) or left edge. # means the causal mask boundary is bottom right aligned, and
# This captures the decrease in n_blocks if we have a rectangular attn # ends at either the top edge (seqlen_q < seqlen_k) or left
# matrix # edge. This captures the decrease in n_blocks if we have a
n_blocks_seqlen = cdiv_fn( # rectangular attn matrix
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) n_blocks_seqlen = cdiv_fn(
# This is what adjusts the block_max for the current WG, only (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks # This is what adjusts the block_max for the current WG, only
n_blocks = min(n_blocks, n_blocks_seqlen) # if IS_CAUSAL. Otherwise we want to always iterate through all
# If we have no blocks after adjusting for seqlen deltas, this WG is # n_blocks
# part of the blocks that are all 0. We exit early. n_blocks = min(n_blocks, n_blocks_seqlen)
if n_blocks <= 0: # If we have no blocks after adjusting for seqlen deltas, this
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + # WG is part of the blocks that are all 0. We exit early.
off_h_q * stride_oh) if n_blocks <= 0:
O_block_ptr = tl.make_block_ptr( o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
base=Out + o_offset, cu_seqlens_q_start * stride_om)
shape=(seqlen_q, BLOCK_DMODEL), o_ptrs = (o_offset + offs_m[:, None] * stride_om +
strides=(stride_om, stride_on), offs_d[None, :] * stride_on)
offsets=(start_m * BLOCK_M, 0), acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_shape=(BLOCK_M, BLOCK_DMODEL), o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to(
order=(1, 0), [BLOCK_M, BLOCK_DMODEL])
) # We still need to write 0s to the result
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) tl.store(o_ptrs, acc, mask=o_ptrs_mask)
# We still need to write 0s to the result # The tensor allocated for L is based on MAX_SEQLENS_Q as
# tl.store(O_block_ptr, # that is statically known.
# acc.to(Out.type.element_ty), boundary_check=(0,1)) l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q off_h_q * MAX_SEQLENS_Q + offs_m)
# + offs_m # We store inf to LSE, not -inf because in the bwd pass,
# We store inf to LSE, not -inf because in the bwd pass, # we subtract this from qk which makes it -inf, such that
# we subtract this # exp(qk - inf) = 0 for these masked blocks.
# from qk which makes it -inf, such that exp(qk - inf) = 0 l_value = tl.full([BLOCK_M],
# for these masked blocks. value=float("inf"),
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) dtype=tl.float32)
# tl.store(l_ptrs, l) l_ptrs_mask = offs_m < MAX_SEQLENS_Q
# TODO: Should dropout and return encoded softmax be handled here? tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
return # TODO: Should dropout and return encoded softmax be
# handled here too?
# If MQA / GQA, set the K and V head offsets appropriately. continue_condition = False
GROUP_SIZE: tl.constexpr = HQ // HK # return
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
if continue_condition:
n_extra_tokens = 0 # If MQA / GQA, set the K and V head offsets appropriately.
if seqlen_k < BLOCK_N: GROUP_SIZE: tl.constexpr = HQ // HK
n_extra_tokens = BLOCK_N - seqlen_k off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
elif seqlen_k % BLOCK_N: n_extra_tokens = 0
n_extra_tokens = seqlen_k % BLOCK_N if seqlen_k < BLOCK_N:
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
# Compute pointers for all the tensors used in this kernel. n_extra_tokens = seqlen_k % BLOCK_N
q_offset = (off_z * stride_qz + off_h_q * stride_qh + USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL
cu_seqlens_q_start * stride_qm) != BLOCK_DMODEL)
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset, # Compute pointers for all the tensors used in this kernel.
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh +
strides=(stride_qm, stride_qk), cu_seqlens_q_start * stride_qm)
offsets=(start_m * BLOCK_M, 0), q_ptrs = (q_offset + offs_m[:, None] * stride_qm +
block_shape=(BLOCK_M, BLOCK_DMODEL), offs_d[None, :] * stride_qk)
order=(1, 0), k_offset = (K + off_z * stride_kz + off_h_k * stride_kh +
) cu_seqlens_k_start * stride_kn)
k_offset = (off_z * stride_kz + off_h_k * stride_kh + k_ptrs = (k_offset + offs_d[:, None] * stride_kk +
cu_seqlens_k_start * stride_kn) offs_n[None, :] * stride_kn)
K_block_ptr = tl.make_block_ptr( v_offset = (V + off_z * stride_vz + off_h_k * stride_vh +
base=K + k_offset, cu_seqlens_k_start * stride_vk)
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), v_ptrs = (v_offset + offs_n[:, None] * stride_vk +
strides=(stride_kk, stride_kn), offs_d[None, :] * stride_vn)
offsets=(0, 0), # Compute pointers for all scale tensors used in this kernel.
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1), IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & (
) not IS_EIGHT_BIT_KV)
v_offset = (off_z * stride_vz + off_h_k * stride_vh + if IS_EIGHT_BIT:
cu_seqlens_k_start * stride_vk) if k_descale_has_singleton:
V_block_ptr = tl.make_block_ptr( k_descale_ptrs = k_descale_ptr
base=V + v_offset, else:
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), k_descale_ptrs = k_descale_ptr + off_h_k
strides=(stride_vk, stride_vn),
offsets=(0, 0), if v_descale_has_singleton:
block_shape=(BLOCK_N, BLOCK_DMODEL), v_descale_ptrs = v_descale_ptr
order=(1, 0), else:
) v_descale_ptrs = v_descale_ptr + off_h_k
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr( if not IS_EIGHT_BIT_KV:
base=bias + off_h_q * stride_bh, if q_descale_has_singleton:
shape=(seqlen_q, seqlen_k), q_descale_ptrs = q_descale_ptr
strides=(stride_bm, stride_bn), else:
offsets=(start_m * BLOCK_M, 0), q_descale_ptrs = q_descale_ptr + off_h_q
block_shape=(BLOCK_M, BLOCK_N), if USE_P_SCALE:
order=(1, 0), if p_descale_has_singleton:
) p_scale_ptrs = p_scale_ptr
else: p_descale_ptrs = p_descale_ptr
bias_ptr = None else:
if ENABLE_DROPOUT: p_scale_ptrs = p_scale_ptr + off_h_q
batch_philox_offset = philox_offset_base \ p_descale_ptrs = p_descale_ptr + off_h_q
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k if USE_BIAS:
else: bias_offset = off_h_q * stride_bh
batch_philox_offset = 0 bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
# We can ask to return the dropout mask without actually doing any dropout. offs_n[None, :] * stride_bn)
# In this case, we return an invalid pointer so indicate the mask is not i else:
# valid. bias_ptrs = None
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX: if USE_ALIBI:
encoded_softmax_block_ptr = tl.make_block_ptr( a_offset = off_z * stride_az + off_h_q * stride_ah
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, alibi_slope = tl.load(alibi_slopes + a_offset)
shape=(seqlen_q, seqlen_k), else:
strides=(seqlen_k, 1), alibi_slope = None
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N), batch_philox_offset = 0
order=(1, 0), # We can ask to return the dropout mask without doing any
) # dropout. In this case, we return an invalid pointer so
else: # indicate the mask is not valid.
encoded_softmax_block_ptr = 0 if SHOULD_RETURN_ENCODED_SOFTMAX:
# initialize pointer to m and l encoded_sm_base = (encoded_softmax +
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) off_h_q * seqlen_q * seqlen_k)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) encoded_sm_ptrs = (encoded_sm_base +
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) offs_m[:, None] * seqlen_k +
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not offs_n[None, :])
# have native e^x support in HW. else:
qk_scale = sm_scale * 1.44269504089 encoded_sm_ptrs = None
# Q is loaded once at the beginning and shared by all N blocks. # initialize pointer to m and l
q = load_fn(Q_block_ptr, True, padded_head, "zero") m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
q = (q * qk_scale).to(Q_block_ptr.type.element_ty) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# Here we compute how many full and masked blocks we have. # scale sm_scale by log_2(e) and use 2^x in the loop as we do
padded_block_k = n_extra_tokens != 0 # not have native e^x support in HW.
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
if IS_CAUSAL: # Q is loaded once at the beginning and shared by all N blocks.
# There are always at least BLOCK_M // BLOCK_N masked blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q
# Additionally there might be one more due to dissimilar seqlens. if USE_PADDED_HEAD:
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) q_ptrs_mask = q_ptrs_mask & (offs_d[None, :]
else: < IS_ACTUAL_BLOCK_DMODEL)
# Padding on Q does not need to be masked in the FA loop. q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional if IS_EIGHT_BIT:
# block. In this case we might exceed n_blocks so pick the min. k_descale = tl.load(k_descale_ptrs)
masked_blocks = min(masked_blocks, n_blocks) v_descale = tl.load(v_descale_ptrs)
n_full_blocks = n_blocks - masked_blocks q_descale = None if IS_EIGHT_BIT_KV else tl.load(
block_min = 0 q_descale_ptrs)
block_max = n_blocks * BLOCK_N if USE_P_SCALE:
# Compute for full blocks. Here we set causal to false regardless of its p_scale = tl.load(p_scale_ptrs)
# value because there is no masking. Similarly we do not need padding. p_descale = tl.load(p_descale_ptrs)
if n_full_blocks > 0: else:
block_max = (n_blocks - masked_blocks) * BLOCK_N p_scale = None
acc, l_i, m_i = _attn_fwd_inner( p_descale = None
acc, else:
l_i, q_descale = None
m_i, k_descale = None
q, v_descale = None
K_block_ptr, p_scale = None
V_block_ptr, p_descale = None
start_m, # Here we compute how many full and masked blocks we have.
seqlen_k, padded_block_k = n_extra_tokens != 0
dropout_p, is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
philox_seed, if IS_CAUSAL:
batch_philox_offset, # There are always at least BLOCK_M // BLOCK_N masked
encoded_softmax_block_ptr, # blocks. Additionally there might be one more due to
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ # dissimilar seqlens.
block_min, masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
block_max, else:
0, # Padding on Q does not need to be masked in the FA loop.
0, masked_blocks = padded_block_k
0, # if IS_CAUSAL, not is_modulo_mn does not always result in an
bias_ptr, # additional block. In this case we might exceed n_blocks so
# IS_CAUSAL, .... # pick the min.
False, masked_blocks = min(masked_blocks, n_blocks)
BLOCK_M, n_full_blocks = n_blocks - masked_blocks
BLOCK_DMODEL, block_min = 0
BLOCK_N, block_max = n_blocks * BLOCK_N
offs_m, # Compute for full blocks. Here we set causal to false
offs_n, # regardless of its actual value because there is no masking.
# _, MASK_STEPS, ... # Similarly we do not need padding.
PRE_LOAD_V, if n_full_blocks > 0:
False, block_max = (n_blocks - masked_blocks) * BLOCK_N
ENABLE_DROPOUT, acc, l_i, m_i = _attn_fwd_inner(
RETURN_ENCODED_SOFTMAX, acc,
padded_head, l_i,
) m_i,
block_min = block_max q,
block_max = n_blocks * BLOCK_N k_ptrs,
v_ptrs,
tl.debug_barrier() bias_ptrs,
# Remaining blocks, if any, are full / not masked. stride_kn,
if masked_blocks > 0: stride_vk,
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 stride_bn,
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) start_m,
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) seqlen_k,
if bias_ptr is not None: seqlen_q,
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) philox_seed,
if RETURN_ENCODED_SOFTMAX: batch_philox_offset,
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, encoded_sm_ptrs,
(0, n_full_blocks)) # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
acc, l_i, m_i = _attn_fwd_inner( block_min,
acc, block_max,
l_i, 0,
m_i, 0,
q, 0,
K_block_ptr, alibi_slope,
V_block_ptr, q_descale,
start_m, k_descale,
seqlen_k, v_descale,
dropout_p, p_scale,
philox_seed, # IS_CAUSAL, ....
batch_philox_offset, False,
encoded_softmax_block_ptr, BLOCK_M,
block_min, BLOCK_DMODEL,
block_max, BLOCK_N,
offs_n_causal, offs_m,
masked_blocks, offs_n,
n_extra_tokens, # _, SHOULD_MASK_STEPS, ...
bias_ptr, SHOULD_PRE_LOAD_V,
IS_CAUSAL, False,
BLOCK_M, SHOULD_RETURN_ENCODED_SOFTMAX,
BLOCK_DMODEL, USE_PADDED_HEAD,
BLOCK_N, IS_ACTUAL_BLOCK_DMODEL,
offs_m, QK_SCALE,
offs_n, IS_EIGHT_BIT_GEMM,
# _, MASK_STEPS, ... USE_P_SCALE,
PRE_LOAD_V, IS_EIGHT_BIT_KV,
True, QUANT_DTYPE)
ENABLE_DROPOUT, block_min = block_max
RETURN_ENCODED_SOFTMAX, block_max = n_blocks * BLOCK_N
padded_head,
) tl.debug_barrier()
# epilogue # Remaining blocks, if any, are full / not masked.
acc = acc / l_i[:, None] if (masked_blocks > 0):
if ENABLE_DROPOUT: if IS_CAUSAL:
acc = acc / (1 - dropout_p) offs_n_causal = offs_n + (seqlen_q - seqlen_k)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, else:
# then we have one block with a row of all NaNs which come from computing offs_n_causal = 0
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here k_ptrs += n_full_blocks * BLOCK_N * stride_kn
# and store 0s where there are NaNs as these rows should've been zeroed out. v_ptrs += n_full_blocks * BLOCK_N * stride_vk
end_m_idx = (start_m + 1) * BLOCK_M if USE_BIAS:
start_m_idx = start_m * BLOCK_M bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
causal_start_idx = seqlen_q - seqlen_k if SHOULD_RETURN_ENCODED_SOFTMAX:
acc = acc.to(Out.type.element_ty) encoded_sm_ptrs += n_full_blocks * BLOCK_N
if IS_CAUSAL: # noqa: SIM102 acc, l_i, m_i = _attn_fwd_inner(
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: acc,
out_mask_boundary = tl.full((BLOCK_DMODEL, ), l_i,
causal_start_idx, m_i,
dtype=tl.int32) q,
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) k_ptrs,
out_ptrs_mask = (mask_m_offsets[:, None] v_ptrs,
>= out_mask_boundary[None, :]) bias_ptrs,
z = 0.0 stride_kn,
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) stride_vk,
# write back LSE stride_bn,
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m start_m,
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last seqlen_k,
# few rows. This is only true for the last M block. For others, seqlen_q,
# overflow_size will be -ve philox_seed,
# overflow_size = end_m_idx - seqlen_q batch_philox_offset,
# if overflow_size > 0: encoded_sm_ptrs,
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) block_min,
# # This is a > check because mask being 0 blocks the store. block_max,
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) offs_n_causal,
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) masked_blocks,
# else: n_extra_tokens,
# tl.store(l_ptrs, m_i + tl.math.log2(l_i)) alibi_slope,
q_descale,
# write back O k_descale,
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + v_descale,
off_h_q * stride_oh) p_scale,
O_block_ptr = tl.make_block_ptr( IS_CAUSAL,
base=Out + o_offset, BLOCK_M,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), BLOCK_DMODEL,
strides=(stride_om, stride_on), BLOCK_N,
offsets=(start_m * BLOCK_M, 0), offs_m,
block_shape=(BLOCK_M, BLOCK_DMODEL), offs_n,
order=(1, 0), # _, SHOULD_MASK_STEPS, ...
) SHOULD_PRE_LOAD_V,
# Need boundary check on this to make sure the padding from the True,
# Q and KV tensors in both dims are not part of what we store back. SHOULD_RETURN_ENCODED_SOFTMAX,
# TODO: Do the boundary check optionally. USE_PADDED_HEAD,
tl.store(O_block_ptr, acc, boundary_check=(0, 1)) IS_ACTUAL_BLOCK_DMODEL,
QK_SCALE,
IS_EIGHT_BIT_GEMM,
def check_args( USE_P_SCALE,
q, IS_EIGHT_BIT_KV,
k, QUANT_DTYPE)
v,
o, if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
varlen=True, if USE_P_SCALE:
max_seqlens=None, acc *= p_descale
cu_seqlens_q=None, acc *= v_descale
cu_seqlens_k=None,
): # epilogue
assert q.dim() == k.dim() and q.dim() == v.dim() # This helps the compiler do Newton Raphson on l_i vs on acc
if varlen: # which is much larger.
assert q.dim() == 3 l_recip = 1 / l_i[:, None]
total_q, nheads_q, head_size = q.shape acc = acc * l_recip
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None # If seqlen_q > seqlen_k but the delta is not a multiple of
assert cu_seqlens_k is not None # BLOCK_M, then we have one block with a row of all NaNs which
assert len(cu_seqlens_q) == len(cu_seqlens_k) # come from computing softmax over a row of all
else: # -infs (-inf - inf = NaN). We check for that here and store 0s
assert q.dim() == 4 # where there are NaNs as these rows should've been zeroed out.
batch, nheads_q, seqlen_q, head_size = q.shape end_m_idx = (start_m + 1) * BLOCK_M
_, nheads_k, seqlen_k, _ = k.shape start_m_idx = start_m * BLOCK_M
assert max_seqlens > 0 causal_start_idx = seqlen_q - seqlen_k
assert k.shape == v.shape if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] if o_descale_ptr is not None:
# TODO: Change assert if we support qkl f8 and v f16 acc = quant_fp8(acc, o_descale)
assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256 acc = acc.to(Out.type.element_ty)
assert o.shape == q.shape if IS_CAUSAL: # noqa: SIM102
assert (nheads_q % nheads_k) == 0 if (causal_start_idx > start_m_idx
and causal_start_idx < end_m_idx):
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :])
z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc,
z.to(acc.type.element_ty))
# write back LSE
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
off_h_q * MAX_SEQLENS_Q + offs_m)
# If seqlen_q not multiple of BLOCK_M, we need to mask out the
# last few rows. This is only true for the last M block.
# For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M, ),
BLOCK_M - overflow_size,
dtype=tl.int32)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
cu_seqlens_q_start * stride_om)
o_ptrs = (o_offset + offs_m[:, None] * stride_om +
offs_d[None, :] * stride_on)
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
if overflow_size > 0:
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
if USE_PADDED_HEAD:
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
def get_shape_from_layout(q, k, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd':
nheads_q, nheads_k = q.shape[1], k.shape[1]
head_size = q.shape[-1]
batch = metadata.num_contexts
elif metadata.layout == 'bhsd':
batch, nheads_q, _, head_size = q.shape
nheads_k = k.shape[1]
elif metadata.layout == 'bshd':
batch, _, nheads_q, head_size = q.shape
nheads_k = k.shape[2]
return batch, nheads_q, nheads_k, head_size
def get_strides_from_layout(q, k, v, o, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
STRIDE_PERMUTATIONS = {
'thd': (None, 1, 0, 2),
'bhsd': (0, 1, 2, 3),
'bshd': (0, 2, 1, 3),
}
perm = STRIDE_PERMUTATIONS[metadata.layout]
stride = lambda x, p: (0 if p is None else x.stride(p))
strides = lambda x: (stride(x, p) for p in perm)
return tuple(strides(x) for x in [q, k, v, o])
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(ctx, q, k, v, o, metadata: MetaData):
ctx, # NOTE: a large bias tensor leads to overflow during pointer arithmetic
q, if (metadata.bias is not None):
k, assert (metadata.bias.numel() < 2**31)
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None: if o is None:
o = torch.empty_like(q, dtype=v.dtype) if metadata.eight_bit:
o = torch.empty_like(
q,
dtype=metadata.output_dtype if metadata.output_dtype
is not None else metadata.eight_bit_dtype_torch)
else:
o = torch.empty_like(q, dtype=q.dtype)
check_args( metadata.check_args(q, k, v, o)
q,
k, batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
v, q, k, metadata)
o, q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
varlen=True, q, k, v, o, metadata)
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. # Get closest power of 2 over or equal to 32.
unpadded_head_dims = {32, 64, 128, 256} padded_d_model = 1 << (head_size - 1).bit_length()
if head_size not in unpadded_head_dims: # Smallest head_dim supported is 16. If smaller, the tile in the
padded_d_model = None # kernel is padded - there is no padding in memory for any dims.
for i in unpadded_head_dims: padded_d_model = max(padded_d_model, 16)
if i > head_size:
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: ( # encoded_softmax is used to validate dropout behavior vs the
triton.cdiv(max_seqlens_q, META["BLOCK_M"]), # PyTorch SDPA math backend reference. We zero this out to give a
nheads_q, # consistent starting point and then populate it with the output of
batch, # softmax with the sign bit set according to the dropout mask.
) # The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else:
encoded_softmax = None
encoded_softmax = None M = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
device=q.device,
dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing. # Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52 philox_seed = 0x1BF52
philox_offset = 0x1D4B42 philox_offset = 0x1D4B42
if bias is not None: if metadata.bias is not None:
bias_strides = ( bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
bias.stride(0), metadata.bias.stride(2), metadata.bias.stride(3))
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else: else:
bias_strides = (0, 0, 0, 0) bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0),
metadata.alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
if metadata.eight_bit:
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid]( attn_fwd[grid](
q, q,
k, k,
v, v,
bias, metadata.bias,
sm_scale, metadata.sm_scale,
None, M,
o, o,
*q_strides, *q_strides,
*k_strides, *k_strides,
*v_strides, *v_strides,
*o_strides, *o_strides,
*bias_strides, *bias_strides,
cu_seqlens_q, *alibi_strides,
cu_seqlens_k, q_descale,
dropout_p=0.0, k_descale,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed, philox_seed=philox_seed,
philox_offset_base=philox_offset, philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HQ=nheads_q,
HK=nheads_k, HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size, IS_ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=causal, IS_CAUSAL=metadata.causal,
VARLEN=True, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1, USE_BIAS=metadata.bias is not None,
ENABLE_DROPOUT=False, USE_ALIBI=metadata.alibi_slopes is not None,
RETURN_ENCODED_SOFTMAX=False, SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
) IS_EIGHT_BIT=metadata.eight_bit,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv,
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid ctx.grid = grid
ctx.sm_scale = sm_scale ctx.sm_scale = metadata.sm_scale
ctx.BLOCK_DMODEL = head_size ctx.BLOCK_DMODEL = head_size
ctx.causal = causal ctx.causal = metadata.causal
ctx.dropout_p = 0.0 ctx.alibi_slopes = metadata.alibi_slopes
ctx.philox_seed = philox_seed ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False ctx.return_encoded_softmax = metadata.return_encoded_softmax
return o, encoded_softmax return o, encoded_softmax
triton_attention = _attention.apply triton_attention_rocm = _attention.apply
def scale_fp8(t, scale=None):
t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]),
scale)
return t_scaled.reshape(t.shape), scale_out
def maybe_quantize_fp8(t, scale):
eight_bit_dtype = current_platform.fp8_dtype()
if t.dtype != eight_bit_dtype:
t, _ = scale_fp8(t, scale)
return t
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
q = maybe_quantize_fp8(q, q_scale)
k = maybe_quantize_fp8(k, k_scale)
v = maybe_quantize_fp8(v, v_scale)
return q, k, v
# query - [num_tokens, num_heads, head_size]
# key - [num_tokens, num_kv_heads, head_size]
# value - [num_tokens, num_kv_heads, head_size
# output - [num_tokens, num_heads, head_size]
def triton_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlens_q: int,
max_seqlens_k: int,
causal: bool = False,
sm_scale: float = 1.0,
bias: Optional[torch.Tensor] = None,
fp8_scales: Optional[tuple[float, ...]] = None,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if fp8_scales is not None:
q_descale, k_descale, v_descale, p_scale = fp8_scales
else:
q_descale = k_descale = v_descale = p_scale = None
attn_metadata = MetaData(sm_scale=sm_scale,
max_seqlens_q=max_seqlens_q,
max_seqlens_k=max_seqlens_k,
causal=causal,
bias=bias,
output_dtype=q.dtype,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=input_scale)
if fp8_scales is not None:
q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)
return triton_attention_rocm(q, k, v, o, attn_metadata)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment