Unverified Commit 8e4b351a authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to...


[Kernel][Triton][FP8] Adding fp8 and variable length sequence support to Triton FAv2 kernel (#12591)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 9869453c
# SPDX-License-Identifier: Apache-2.0
"""Tests for the triton_flash_attention kernel
Run `pytest tests/kernels/test_triton_flash_attention.py`.
"""
import pytest
import torch
from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS,
MetaData,
compute_alibi_tensor,
scale_fp8,
triton_attention_rocm)
from vllm.platforms import current_platform
class ReferenceAttention:
    """Eager PyTorch attention used as ground truth for the Triton kernel.

    The object stores the problem sizes plus the ``input_metadata`` object
    handed to the kernel under test, and reads ``sm_scale``, ``causal``,
    ``bias``, ``alibi_slopes``, ``layout``, the fp8 scales and the varlen
    offsets from that metadata.
    """

    def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype,
                 input_metadata):
        self.Z = Z
        self.HQ = HQ
        self.HK = HK
        self.N_CTX_Q = N_CTX_Q
        self.N_CTX_K = N_CTX_K
        self.D_HEAD = D_HEAD
        self.use_alibi = use_alibi
        self.dtype = dtype
        self.input_metadata = input_metadata

    def fwd(self, q, k, v):
        """Dense attention over ``(batch, heads, seq, head_dim)`` inputs.

        Returns the output in ``bhsd`` order, transposed to ``bshd`` when the
        metadata layout is ``'bshd'``.
        """
        scores = torch.einsum('bhqd,bhkd->bhqk', q,
                              k).float() * self.input_metadata.sm_scale
        if self.input_metadata.causal:
            # Bottom-right aligned causal mask. Built on the same device as
            # the scores so the reference follows its inputs instead of
            # assuming a GPU is present.
            mask = torch.tril(torch.ones(self.N_CTX_Q,
                                         self.N_CTX_K,
                                         device=scores.device),
                              diagonal=self.N_CTX_K - self.N_CTX_Q)
            scores[:, :, mask == 0] = float("-inf")
        if self.input_metadata.bias is not None:
            scores += self.input_metadata.bias
        if self.use_alibi:
            scores += compute_alibi_tensor(self.input_metadata.alibi_slopes,
                                           self.N_CTX_Q, self.N_CTX_K)
        p = torch.softmax(scores, dim=-1)
        if self.input_metadata.causal:
            # If N_CTX_Q > N_CTX_K, there's at least one row of all -infs
            # going into softmax. This creates a row of NaNs as
            # -inf - -inf == NaN. Those rows should contribute nothing, so
            # the NaNs are replaced by 0.
            p[torch.isnan(p)] = 0
        ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v)
        if self.input_metadata.layout == 'bshd':
            ref_out = ref_out.transpose(1, 2).clone()
        return ref_out

    def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
        """Dequantize q/k/v with their descales, run ``fwd`` and, when an
        ``o_scale`` is present, quantize the result with ``scale_fp8``."""
        q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
            self.dtype)
        k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
            self.dtype)
        v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
            self.dtype)
        result = self.fwd(q, k, v)
        if self.input_metadata.o_scale is not None:
            result, _ = scale_fp8(result, self.input_metadata.o_scale)
        return result

    def fwd_fp8_kv(self, q, k_quantized, v_quantized):
        """Run ``fwd`` with an unquantized q and fp8 k/v dequantized in
        float32 before being cast to ``self.dtype``."""
        k_descale, v_descale = (self.input_metadata.k_descale,
                                self.input_metadata.v_descale)
        k_dequantized = (k_quantized.to(torch.float32) *
                         k_descale.to(torch.float32)).to(self.dtype)
        v_dequantized = (v_quantized.to(torch.float32) *
                         v_descale.to(torch.float32)).to(self.dtype)
        return self.fwd(q, k_dequantized, v_dequantized)

    def varlen_fwd(self, q, k, v, is_mqa=False):
        """Non-causal attention over packed ``(tokens, heads, head_dim)``
        inputs, one sequence at a time as delimited by the metadata's
        ``cu_seqlens_q`` / ``cu_seqlens_k``.

        With ``is_mqa`` each K/V head is repeated ``HQ // HK`` times so that
        it lines up with the query heads.
        """
        ref_out = torch.empty_like(q)
        if is_mqa:
            # Make KV look like HQ/HK "groups" of HK. Later, we will reshape
            # so the size aligns with Q.
            group = self.HQ // self.HK
            k_ref = k.view(k.shape[0], k.shape[1], 1,
                           k.shape[2]).expand(-1, -1, group, -1)
            v_ref = v.view(v.shape[0], v.shape[1], 1,
                           v.shape[2]).expand(-1, -1, group, -1)
        else:
            k_ref = k
            v_ref = v
        cu_q = self.input_metadata.cu_seqlens_q
        cu_k = self.input_metadata.cu_seqlens_k
        for i in range(self.input_metadata.num_contexts):
            start_q, end_q = cu_q[i], cu_q[i + 1]
            start_k, end_k = cu_k[i], cu_k[i + 1]
            k_curr = k_ref[start_k:end_k]
            v_curr = v_ref[start_k:end_k]
            if is_mqa:
                k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3])
                v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3])
            scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q],
                                  k_curr).float()
            # Cast to the configured dtype (previously a hard-coded .half())
            # so the P @ V product also works when dtype is not float16.
            p = torch.softmax(scores * self.input_metadata.sm_scale,
                              dim=-1).to(self.dtype)
            ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr)
        return ref_out
def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False):
    """Quantize attention inputs with ``scale_fp8``.

    K and V are always quantized. Q is quantized only when ``fp8_kv`` is
    false, otherwise it is returned untouched with a ``None`` descale.
    Returns ``(q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale)``.
    """
    if fp8_kv:
        q_descale = None
    else:
        q, q_descale = scale_fp8(q)
    k, k_descale = scale_fp8(k)
    v, v_descale = scale_fp8(v)

    # In real world use case, the p scale would be a parameter trained by the
    # model.
    p_scale = None

    if use_o_scale:
        o_scale = torch.rand(1, device="cuda", requires_grad=False)
    else:
        o_scale = None

    return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale
def input_helper(
    Z,
    HQ,
    HK,
    N_CTX_Q,
    N_CTX_K,
    D_HEAD,
    dtype,
    layout=None,
    use_alibi=None,
    causal=None,
    is_fp8=False,
    fp8_kv=False,
    use_o_scale=False,
    use_bias=False,
):
    """Build random dense q/k/v tensors and the matching ``MetaData``.

    Tensors are created on ``"cuda"`` in the requested ``layout``
    (``'bhsd'`` or ``'bshd'``). Optional ALiBi slopes, an additive bias and
    fp8 quantization are set up on request, and ``causal`` is forwarded to
    the metadata so that causal parametrizations exercise causal masking.

    Returns ``(q, k, v, input_metadata)``.
    """
    assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
    current_platform.seed_everything(0)

    # Initialize q, k, v
    if layout == 'bhsd':
        q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
        k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
    elif layout == 'bshd':
        q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
        k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
    else:
        # Any other supported layout (e.g. 'thd') has no dense shape here;
        # fail clearly instead of hitting an unbound variable further down.
        raise ValueError(
            f"input_helper only builds 'bhsd' or 'bshd' inputs, got {layout!r}"
        )

    if use_alibi:
        # for n heads the set of slopes is the geometric sequence that starts
        # 2^(-8/n)
        alibi_slopes = torch.tensor(
            [2**(-8 / HQ * i) for i in range(1, HQ + 1)],
            dtype=torch.float32,
            device="cuda").repeat(Z, 1)
    else:
        alibi_slopes = None

    if use_bias:
        bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K),
                           dtype=dtype,
                           device="cuda",
                           requires_grad=False)
    else:
        bias = None

    q = torch.randn(q_tensor_shape,
                    dtype=dtype,
                    device="cuda",
                    requires_grad=False)
    k = torch.randn(k_tensor_shape,
                    dtype=dtype,
                    device="cuda",
                    requires_grad=False)
    v = torch.randn(k_tensor_shape,
                    dtype=dtype,
                    device="cuda",
                    requires_grad=False)

    if is_fp8:
        (q, k, v, q_descale, k_descale, v_descale, p_scale,
         o_scale) = quantize_input(q,
                                   k,
                                   v,
                                   use_o_scale=use_o_scale,
                                   fp8_kv=fp8_kv)
    else:
        q_descale = k_descale = v_descale = p_scale = o_scale = None

    input_metadata = MetaData(sm_scale=D_HEAD**-0.5,
                              max_seqlens_q=N_CTX_Q,
                              max_seqlens_k=N_CTX_K,
                              layout=layout,
                              alibi_slopes=alibi_slopes,
                              alibi_batch=Z,
                              alibi_nheads=HQ,
                              q_descale=q_descale,
                              k_descale=k_descale,
                              v_descale=v_descale,
                              p_scale=p_scale,
                              o_scale=o_scale,
                              bias=bias,
                              seqlen_q=N_CTX_Q,
                              seqlen_k=N_CTX_K,
                              # Previously dropped: without this every
                              # causal=True test ran as non-causal.
                              causal=causal)
    return q, k, v, input_metadata
def varlen_input_helper(Z,
                        HQ,
                        HK,
                        N_CTX_Q,
                        N_CTX_K,
                        D_HEAD,
                        dtype,
                        equal_seqlens=False):
    """Build packed variable-length q/k/v tensors and their ``MetaData``.

    ``Z`` sequence lengths are drawn per tensor (or fixed to ``N_CTX // Z``
    when ``equal_seqlens`` is set), turned into cumulative offsets on
    ``"cuda"``, and registered on the metadata via ``set_varlen_params``.

    Returns ``(q, k, v, input_metadata)`` with q/k/v shaped
    ``(total_tokens, heads, D_HEAD)``.
    """
    current_platform.seed_everything(0)

    def _cumulative(lengths):
        # Prefix-sum with a leading zero, moved to the GPU.
        zero = torch.tensor([0], dtype=torch.int32)
        running = lengths.cumsum(dim=0, dtype=torch.int32)
        return torch.cat([zero, running]).to(device="cuda")

    def _random_packed(num_tokens, num_heads):
        tensor = torch.randn((num_tokens, num_heads, D_HEAD),
                             dtype=dtype,
                             device="cuda")
        return tensor.normal_(mean=0., std=0.5).requires_grad_()

    # Random sequence lengths. Using N_CTX as kind of max of sum of individual
    # seqs
    if equal_seqlens:
        seqlens_q = torch.full((Z, ), N_CTX_Q // Z)
        seqlens_k = torch.full((Z, ), N_CTX_K // Z)
    else:
        upper_q = N_CTX_Q // Z
        upper_k = N_CTX_K // Z
        seqlens_q = torch.randint(1, upper_q + 1, (Z, ), dtype=torch.int32)
        seqlens_k = torch.randint(1, upper_k + 1, (Z, ), dtype=torch.int32)

    # Calculate cumulative sequence lengths
    cu_seqlens_q = _cumulative(seqlens_q)
    cu_seqlens_k = _cumulative(seqlens_k)

    # Initialize q, k, v with variable lengths
    total_q = cu_seqlens_q[-1].item()
    total_k = cu_seqlens_k[-1].item()
    q = _random_packed(total_q, HQ)
    k = _random_packed(total_k, HK)
    v = _random_packed(total_k, HK)

    input_metadata = MetaData(sm_scale=D_HEAD**-0.5)
    input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
    return q, k, v, input_metadata
@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [
    (1, 48, 12, 1, 1, 64),
    (4, 4, 4, 128, 128, 65),
    (16, 48, 48, 1, 1, 128),
    (64, 48, 24, 3, 3, 128),
    (4, 4, 4, 113, 123, 1),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_alibi', [True, False])
@pytest.mark.parametrize('layout', ['bshd'])
def test_op_fwd(Z,
                HQ,
                HK,
                N_CTX_Q,
                N_CTX_K,
                D_HEAD,
                causal,
                use_alibi,
                layout,
                dtype=torch.float16):
    """Compare the Triton forward kernel with the eager reference for dense
    fp16 inputs, including MQA/GQA head counts and optional ALiBi."""
    current_platform.seed_everything(0)
    q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD,
                                           dtype, layout, use_alibi, causal)

    # triton implementation
    out_buffer = torch.empty_like(q)
    tri_out, _ = triton_attention_rocm(q, k, v, out_buffer, input_metadata)

    # Transpose here if layout is bshd so we have same reference code for all
    # layouts
    if layout == 'bshd':
        q, k, v = (t.transpose(1, 2).clone() for t in (q, k, v))

    # Replicate K and V if using MQA/GQA
    if HQ != HK:
        group = HQ // HK

        def _replicate_heads(t):
            batch, _, seq, dim = t.shape
            expanded = t.unsqueeze(2).expand(-1, -1, group, -1, -1)
            return expanded.reshape(batch, -1, seq, dim)

        k = _replicate_heads(k)
        v = _replicate_heads(v)

    reference = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD,
                                   use_alibi, dtype, input_metadata)
    ref_out = reference.fwd(q, k, v)
    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
    (4, 48, 1, 1, 64),
    (4, 48, 1, 1, 128),
    (4, 48, 3, 3, 128),
    (4, 4, 128, 128, 65),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('layout', ['bhsd'])
@pytest.mark.parametrize('use_o_scale', [True, False])
# The availability check comes first: get_device_capability() raises when no
# GPU is present, which would otherwise break collection of the whole file.
@pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() < (9, 0),
    reason="Triton FP8 requires compute capability 9.0 or higher")
def test_op_fwd_fp8(Z,
                    H,
                    N_CTX_Q,
                    N_CTX_K,
                    D_HEAD,
                    causal,
                    layout,
                    use_o_scale,
                    dtype=torch.float32):
    """Compare the kernel with the eager reference when q, k and v are all
    fp8-quantized, with and without an output scale."""
    current_platform.seed_everything(0)

    q_quantized, k_quantized, v_quantized, input_metadata = input_helper(
        Z,
        H,
        H,
        N_CTX_Q,
        N_CTX_K,
        D_HEAD,
        dtype,
        causal=causal,
        layout=layout,
        is_fp8=True,
        use_o_scale=use_o_scale)

    # An output buffer is only supplied when the output is scaled; otherwise
    # None is passed to the kernel wrapper.
    o = torch.empty_like(q_quantized) if use_o_scale else None

    tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized,
                                       o, input_metadata)

    ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
                                  dtype, input_metadata)
    ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized)

    # compare in float32 so both sides share a dtype
    torch.testing.assert_close(ref_out.to(torch.float32),
                               tri_out.to(torch.float32),
                               atol=7e-2,
                               rtol=2e-1)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
    (4, 48, 1, 1, 64),
    (4, 48, 1, 1, 128),
    (4, 48, 3, 3, 128),
    (4, 4, 128, 128, 65),
    (4, 4, 113, 123, 1),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('layout', ['bhsd'])
def test_op_fwd_fp8_kv(Z,
                       H,
                       N_CTX_Q,
                       N_CTX_K,
                       D_HEAD,
                       causal,
                       layout,
                       dtype=torch.float32):
    """Compare the kernel with the eager reference when only K and V are
    fp8-quantized and Q keeps its original dtype."""
    current_platform.seed_everything(0)

    helper_kwargs = dict(causal=causal,
                         layout=layout,
                         is_fp8=True,
                         fp8_kv=True)
    q, k_quantized, v_quantized, input_metadata = input_helper(
        Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, **helper_kwargs)

    out_buffer = torch.empty_like(q)
    tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized,
                                       out_buffer, input_metadata)

    reference = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
                                   dtype, input_metadata)
    ref_out = reference.fwd_fp8_kv(q, k_quantized, v_quantized)

    torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1)
@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
    (4, 48, 1, 1, 64),
    (4, 48, 1, 1, 128),
    (4, 48, 3, 3, 128),
    (4, 4, 128, 128, 65),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_bias', [True])
@pytest.mark.parametrize('dtype', [torch.bfloat16])
def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
    """Compare the kernel with the eager reference when an additive bias
    matrix is supplied (bf16, ``bhsd`` layout)."""
    current_platform.seed_everything(0)

    helper_kwargs = dict(layout='bhsd', causal=causal, use_bias=use_bias)
    q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD,
                                           dtype, **helper_kwargs)

    # triton implementation
    out_buffer = torch.empty_like(q)
    tri_out, _ = triton_attention_rocm(q, k, v, out_buffer, input_metadata)

    reference = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False,
                                   dtype, input_metadata)
    ref_out = reference.fwd(q, k, v)

    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64),
                                                 (4, 48, 512, 64),
                                                 (16, 48, 512, 64),
                                                 (64, 48, 128, 128)])
@pytest.mark.parametrize('causal', [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
    """Compare the kernel with the per-sequence eager reference on packed
    variable-length inputs with equal query and key head counts.

    NOTE(review): ``causal`` is parametrized but is not passed to the input
    helper, the kernel metadata or the reference in this test.
    """
    q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX,
                                                  D_HEAD, dtype)

    # The kernel writes into tri_out; its return value is not used here.
    tri_out = torch.empty_like(q)
    triton_attention_rocm(q, k, v, tri_out, input_metadata)

    reference = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype,
                                   input_metadata)
    ref_out = reference.varlen_fwd(q, k, v, is_mqa=False)

    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
# NOTE: Uses thd layout, so also tests thd.
@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64),
                                                      (4, 48, 12, 256, 64),
                                                      (4, 48, 4, 512, 64),
                                                      (4, 64, 16, 128, 128)])
@pytest.mark.parametrize('causal', [False])
def test_op_varlen_mqa_fwd(Z,
                           HQ,
                           HK,
                           N_CTX,
                           D_HEAD,
                           causal,
                           dtype=torch.float16):
    """Compare the kernel with the per-sequence eager reference on packed
    variable-length inputs where K/V have fewer heads than Q (MQA/GQA)."""
    q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX,
                                                  D_HEAD, dtype)

    # The kernel writes into tri_out; its return value is not used here.
    tri_out = torch.empty_like(q)
    triton_attention_rocm(q, k, v, tri_out, input_metadata)

    reference = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False,
                                   dtype, input_metadata)
    ref_out = reference.varlen_fwd(q, k, v, is_mqa=True)

    torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao This is a Triton implementation of the Flash Attention v2 algorithm
(https://tridao.me/publications/flash2/flash2.pdf) See https://tridao.me/publications/flash2/flash2.pdf
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported: Credits:
AMD Triton kernels team
OpenAI kernel team
1) Fwd with causal masking Currently only the forward kernel is supported, and contains these features:
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims 1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths
3) Arbitrary head sizes
4) Multi and grouped query attention
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
""" """
from typing import Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
torch_dtype: tl.constexpr = torch.float16 from vllm import _custom_ops as ops
from vllm.platforms import current_platform
# Layout names accepted for the q/k/v/o tensors.
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

# Default 8-bit types: the Triton-side type is fixed, the torch-side type is
# whatever the current platform reports.
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)

# Clamp bounds used when quantizing to fp8 inside the kernels.
FP8_MIN = tl.constexpr(default_float8_info.min)

# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
# NOTE(review): FP8_MIN is not adjusted the same way — confirm that the
# asymmetric range is intended.
if default_eight_bit_dtype_torch != torch.float8_e4m3fnuz:
    FP8_MAX = tl.constexpr(default_float8_info.max)
else:
    FP8_MAX = tl.constexpr(224.0)
class MetaData:
    """Configuration bundle describing one attention call.

    Holds the softmax scale, tensor layout, maximum sequence lengths and the
    optional features (variable-length batching, additive bias, ALiBi, causal
    masking, fp8 quantization scales). Every attribute has a class-level
    default, so it can be read on any instance whether or not the
    corresponding feature was enabled.
    """
    cu_seqlens_q = None
    cu_seqlens_k = None
    max_seqlens_q = 0
    max_seqlens_k = 0
    bias = None
    alibi_slopes = None
    causal = False
    num_contexts = 0
    varlen = False
    eight_bit = False
    layout = None
    return_encoded_softmax = False
    eight_bit_dtype_triton = default_eight_bit_dtype_triton
    eight_bit_dtype_torch = default_eight_bit_dtype_torch
    output_dtype = None
    # Quantization state. These used to exist only after
    # set_eight_bit_params() ran, so reading e.g. `o_scale` on a
    # non-quantized instance raised AttributeError.
    q_descale = None
    k_descale = None
    v_descale = None
    p_scale = None
    p_descale = None
    o_scale = None
    use_p_scale = False
    eight_bit_kv = False

    # Note about layouts:
    #
    # thd - [num_tokens, num_heads, head_size]
    # bshd - [batch_size, seq_len, num_heads, head_size]
    # bhsd - [batch_size, num_heads, seq_len, head_size]
    #
    # This is for each tensor, all tensors must have same layout.
    # Q can have num_heads and seq_len differ from K and V,
    # however K and V must agree on this.
    #
    # Notes about varlen and bias:
    # Only one or the other is implemented, meaning can't combine
    # both varlen and bias right now.
    #
    # Note about quantization:
    # Only 8-bit quantization supported (for now) and specifically fp8.
    # Scales must be tensors.
    # o_scale: This is 'output scaling', but comes from parameter called
    # 'input_scale', this is applied to the output from the kernel.
    # o_scale should be None if none of the other quantization parameters
    # are used.
    #
    # NOTE: Object is in a tentatively good state after initialized, however,
    # to verify, call check_args(q,k,v,o) where o is the output tensor.
    def __init__(
        self,
        sm_scale=1.0,
        layout=None,  # layout can be 'bshd', 'bhsd', or 'thd'
        output_dtype=None,
        max_seqlens_q=0,
        max_seqlens_k=0,
        # varlen params
        cu_seqlens_q=None,  # only 'thd' layout supported for varlen
        cu_seqlens_k=None,
        # quant params
        q_descale=None,
        k_descale=None,
        v_descale=None,
        p_scale=None,
        o_scale=None,
        # bias params
        bias=None,  # varlen not implemented for bias
        seqlen_q=None,
        seqlen_k=None,
        # alibi params
        alibi_slopes=None,
        alibi_batch=None,
        alibi_nheads=None,
        # causal
        causal=None,
    ):
        """Store the basic parameters and enable each optional feature whose
        arguments were supplied."""
        self.sm_scale = sm_scale
        self.output_dtype = output_dtype
        self.max_seqlens_q = max_seqlens_q
        self.max_seqlens_k = max_seqlens_k
        self.layout = layout
        if cu_seqlens_q is not None or cu_seqlens_k is not None:
            assert cu_seqlens_q is not None and cu_seqlens_k is not None
            assert layout is None or layout not in [
                'bshd', 'bhsd'
            ], "Varlen only implemented for thd layout"
            self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
        quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
        if any(x is not None for x in quant_params):
            p_descale = 1.0 / p_scale if p_scale is not None else None
            self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
                                      p_descale, o_scale)
        if bias is not None:
            self.need_bias(bias, seqlen_q, seqlen_k)
        if alibi_slopes is not None:
            self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
        if causal:
            self.need_causal()

    def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
        """Switch to variable-length mode ('thd' layout).

        ``cu_seqlens_q`` / ``cu_seqlens_k`` are cumulative sequence offsets
        of equal length (at least 2). The maximum per-sequence lengths are
        derived from them and never shrink below the values already stored.
        """
        # Without "varlen", there should still be one sequence.
        assert len(cu_seqlens_q) >= 2
        assert len(cu_seqlens_q) == len(cu_seqlens_k)
        self.varlen = True
        self.layout = 'thd'
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_k = cu_seqlens_k
        self.num_contexts = len(cu_seqlens_q) - 1
        # One reduction per tensor instead of four .item() calls per
        # sequence in a Python loop.
        longest_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item()
        longest_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item()
        self.max_seqlens_q = max(longest_q, self.max_seqlens_q)
        self.max_seqlens_k = max(longest_k, self.max_seqlens_k)

    def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
                             p_descale, o_scale):
        """Enable 8-bit mode and record the quantization scales.

        ``use_p_scale`` is set when p_scale, p_descale and v_descale are all
        present. ``eight_bit_kv`` is set when only K and V carry descales
        (no q_descale).
        """
        self.eight_bit = True
        self.q_descale = q_descale
        self.k_descale = k_descale
        self.v_descale = v_descale
        self.p_scale = p_scale
        self.p_descale = p_descale
        self.o_scale = o_scale
        self.use_p_scale = (p_scale is not None) and (
            p_descale is not None) and (v_descale is not None)
        self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
                             and (v_descale is not None))
        self.eight_bit_dtype_torch = default_eight_bit_dtype_torch

    def need_bias(self, bias, seqlen_q, seqlen_k):
        """Attach a 4-D CUDA bias whose leading dim is 1 and whose trailing
        dims are ``(seqlen_q, seqlen_k)``."""
        assert bias is not None
        assert bias.is_cuda
        assert bias.dim() == 4
        assert bias.shape[0] == 1
        assert bias.shape[2:] == (seqlen_q, seqlen_k)
        self.bias = bias

    def need_alibi(self, alibi_slopes, batch, nheads):
        """Attach CUDA ALiBi slopes of shape ``(batch, nheads)``."""
        assert alibi_slopes.is_cuda
        assert alibi_slopes.dim() == 2
        assert alibi_slopes.shape[0] == batch
        assert alibi_slopes.shape[1] == nheads
        self.alibi_slopes = alibi_slopes

    def need_causal(self):
        """Enable causal masking."""
        self.causal = True

    def check_args(self, q, k, v, o):
        """Assert that q, k, v and the output tensor ``o`` are consistent
        with this metadata (rank, layout, dtypes, head size, head counts)."""
        assert q.dim() == k.dim() and q.dim() == v.dim()

        batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
            q, k, self)
        if self.varlen:
            assert q.dim() == 3
            assert self.cu_seqlens_q is not None
            assert self.cu_seqlens_k is not None
            assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
            # TODO: Remove once bias is supported with varlen
            assert self.bias is None
            assert not self.return_encoded_softmax
        else:
            assert q.dim() == 4
            assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
            assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
        assert k.shape == v.shape
        assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
        # TODO: Change assert if we support qkl f8 and v f16
        if self.eight_bit:
            if self.eight_bit_kv:
                assert (v.dtype == k.dtype
                        and k.dtype == self.eight_bit_dtype_torch)
                assert q.dtype != k.dtype
                assert (self.v_descale is not None) and (self.k_descale
                                                         is not None)
            else:
                assert (q.dtype == k.dtype and q.dtype == v.dtype
                        and q.dtype == self.eight_bit_dtype_torch)
                assert (self.q_descale
                        is not None) and (self.k_descale
                                          is not None) and (self.v_descale
                                                            is not None)
                if self.use_p_scale:
                    assert (self.p_scale is not None) and (self.p_descale
                                                           is not None)
        else:
            assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
        assert head_size <= 256
        assert o.shape == q.shape
        assert (nheads_q % nheads_k) == 0
        assert self.layout is not None
        assert self.layout == 'thd' or not self.varlen
@triton.jit @triton.jit
...@@ -38,40 +244,85 @@ def max_fn(x, y): ...@@ -38,40 +244,85 @@ def max_fn(x, y):
return tl.math.max(x, y) return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit @triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): def masked_load(ptrs, offset_first, offset_second, boundary_first,
ms = tl.arange(0, m) boundary_second):
ns = tl.arange(0, n) if offset_first is not None and offset_second is not None:
return philox_offset + ms[:, None] * stride + ns[None, :] mask = (offset_first[:, None] < boundary_first) & \
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit @triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): def compute_alibi_block(alibi_slope,
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, seqlen_q,
stride).to(tl.uint32) seqlen_k,
# TODO: use tl.randint for better performance offs_m,
return tl.rand(philox_seed, rng_offsets) offs_n,
transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to
# the bottom right of the attention matrix
# for causal mask we want something like this where (1 is kept and 0 is
# masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that
# spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
# offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] =
# [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q -
offs_n[None, :])
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
@triton.jit def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k):
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): q_idx = torch.arange(seqlen_q, dtype=torch.int32,
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
stride) k_idx = torch.arange(seqlen_k, dtype=torch.int32,
rng_keep = rng_output > dropout_p device="cuda").unsqueeze(0) # (1, N_CTX_K)
return rng_keep relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q -
k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit @triton.jit
def load_fn(block_ptr, first, second, pad): def quant_fp8(x, scale):
if first and second: x *= scale
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) x = tl.clamp(x, FP8_MIN, FP8_MAX)
elif first: return x
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit @triton.jit
...@@ -80,58 +331,68 @@ def _attn_fwd_inner( ...@@ -80,58 +331,68 @@ def _attn_fwd_inner(
l_i, l_i,
m_i, m_i,
q, q,
K_block_ptr, k_ptrs,
V_block_ptr, v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
actual_seqlen_k, actual_seqlen_k,
dropout_p, actual_seqlen_q,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_softmax_block_ptr, encoded_sm_ptrs,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
bias_ptr, alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr, OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr, SHOULD_MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_PADDED_HEAD: tl.constexpr,
PADDED_HEAD: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
QK_SCALE: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
# loop over k, v, and update accumulator # loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N): for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if # For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range. # we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn( k_offs_n = start_n + tl.arange(0,
K_block_ptr, BLOCK_N) if SHOULD_MASK_STEPS else None
PADDED_HEAD, k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0), k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL,
"zero", actual_seqlen_k)
) if SHOULD_PRE_LOAD_V:
if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed.
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need # We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n # to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block. # TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102 if SHOULD_MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to # If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size # mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not
# if not is_modulo_mn. last step might get wasted but that is okay. # is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case. # check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], boundary_m = tl.full([BLOCK_M],
...@@ -144,167 +405,276 @@ def _attn_fwd_inner( ...@@ -144,167 +405,276 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf")) qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ---- # -- compute qk ----
qk += tl.dot(q, k) if IS_EIGHT_BIT_GEMM:
if bias_ptr is not None: qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) *
bias = load_fn(bias_ptr, False, MASK_STEPS QK_SCALE)
and (n_extra_tokens != 0), "zero") else:
# While bias is added after multiplying qk with sm_scale, our if IS_EIGHT_BIT_KV:
# optimization to use 2^x instead of e^x results in an additional k = (k * k_descale).to(q.type.element_ty)
# scale factor of log2(e) which we must also multiply the bias with. qk += (tl.dot(q, k) * QK_SCALE)
qk += bias * 1.44269504089
if bias_ptrs is not None:
bias_offs_n = start_n + tl.arange(
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None] qk = qk - m_ij[:, None]
p = tl.math.exp2(qk) p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout # CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT: if SHOULD_RETURN_ENCODED_SOFTMAX:
philox_offset = (batch_philox_offset + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty))
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator -- # -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij) alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
if not PRE_LOAD_V: if not SHOULD_PRE_LOAD_V:
v = load_fn( v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k,
V_block_ptr, IS_ACTUAL_BLOCK_DMODEL)
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i # -- update m_i and l_i
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
# update m_i and l_i # update m_i and l_i
m_i = m_ij m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) if IS_EIGHT_BIT_GEMM:
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if USE_P_SCALE:
if bias_ptr is not None: p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) acc += tl.dot(p, v)
if RETURN_ENCODED_SOFTMAX: else:
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, # v is in eight_bit but p is not, we want the gemm in p's type
(0, BLOCK_N)) acc += tl.dot(p, v.to(p.type.element_ty))
else:
if IS_EIGHT_BIT_KV:
v = (v * v_descale).to(p.type.element_ty)
acc += tl.dot(p.to(v.type.element_ty), v)
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i return acc, l_i, m_i
@triton.autotune( def get_cdna_autotune_configs():
configs=[ return [
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 128,
"BLOCK_N": 64, 'BLOCK_N': 128,
"waves_per_eu": 2, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 128,
"BLOCK_N": 128, 'BLOCK_N': 64,
"waves_per_eu": 2, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 256, 'BLOCK_M': 128,
"BLOCK_N": 128, 'BLOCK_N': 64,
"waves_per_eu": 2, 'waves_per_eu': 3,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 128,
"BLOCK_N": 64, 'BLOCK_N': 64,
"waves_per_eu": 1, 'waves_per_eu': 1,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=4),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 128,
"BLOCK_N": 64, 'BLOCK_N': 32,
"waves_per_eu": 3, 'waves_per_eu': 2,
"PRE_LOAD_V": True, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=4),
), ], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
"BLOCK_M": 128, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 32,
"waves_per_eu": 3, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 64, 'BLOCK_M': 32,
"BLOCK_N": 64, 'BLOCK_N': 32,
"waves_per_eu": 4, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
triton.Config( triton.Config(
{ {
"BLOCK_M": 32, 'BLOCK_M': 32,
"BLOCK_N": 32, 'BLOCK_N': 16,
"waves_per_eu": 4, 'waves_per_eu': 4,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=8, num_warps=2),
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config( triton.Config(
{ {
"BLOCK_M": 16, 'BLOCK_M': 32,
"BLOCK_N": 16, 'BLOCK_N': 16,
"waves_per_eu": 1, 'waves_per_eu': 2,
"PRE_LOAD_V": False, 'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4, num_warps=2),
), triton.Config(
], {
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], 'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=2),
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=2),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_general_autotune_configs():
    """Build the autotune search space for the non-ROCm fall-through case.

    This is the set ``get_autotune_configs`` selects when the current
    platform is not ROCm.

    Returns:
        A ``(configs, keys)`` pair. ``configs`` is a list of three
        ``triton.Config`` objects that share ``BLOCK_M=128``,
        ``num_stages=1`` and ``num_warps=4`` and differ only in
        ``BLOCK_N`` (128, 64, 32, in that order). ``keys`` is the list of
        kernel argument names used as the autotune key.
    """
    # Only the key/value tile width varies between the candidates, so the
    # configs are generated from the list of BLOCK_N values.
    block_n_candidates = (128, 64, 32)
    configs = [
        triton.Config(
            {
                'BLOCK_M': 128,
                'BLOCK_N': block_n,
                'SHOULD_PRE_LOAD_V': False,
                'GRID_CU_MULTIP': 2,
            },
            num_stages=1,
            num_warps=4,
        ) for block_n in block_n_candidates
    ]
    keys = [
        'IS_CAUSAL',
        'MAX_SEQLENS_Q',
        'MAX_SEQLENS_K',
        'IS_ACTUAL_BLOCK_DMODEL',
        'VARLEN',
        'HQ',
        'HK',
    ]
    return configs, keys
def has_cdna_target():
    """Return whether Triton's current target arch is a listed CDNA arch.

    Queries the active Triton driver for the current compilation target
    and checks its ``arch`` attribute against a fixed list of ROCm CDNA
    architecture names.
    """
    cdna_archs = ("gfx940", "gfx941", "gfx942", "gfx90a", "gfx908")
    current_target = triton.runtime.driver.active.get_current_target()
    return current_target.arch in cdna_archs
def is_rocm_cdna():
    """Return a truthy value only on a ROCm platform with a CDNA target."""
    on_rocm = current_platform.is_rocm()
    if not on_rocm:
        # Short-circuit: the Triton driver is only queried on ROCm.
        return on_rocm
    return has_cdna_target()
def get_autotune_configs():
    """Pick the autotune ``(configs, keys)`` pair for the current platform.

    ROCm with a CDNA target gets the CDNA set, any other ROCm device gets
    the RDNA set, and every other platform gets the general set.

    Returns:
        Whatever the selected builder returns: a pair of a list of
        ``triton.Config`` objects and a list of autotune key names.
    """
    # Choose the builder first, then call it once at the end.
    if is_rocm_cdna():
        build_configs = get_cdna_autotune_configs
    elif current_platform.is_rocm():
        build_configs = get_rdna_autotune_configs
    else:
        build_configs = get_general_autotune_configs
    return build_configs()
# Resolve the platform-specific autotune configs and key names once, so the
# @triton.autotune decorator that follows can take them as `configs` / `key`.
autotune_configs, autotune_keys = get_autotune_configs()
@triton.autotune(
configs=autotune_configs,
key=autotune_keys,
use_cuda_graph=True,
) )
@triton.jit @triton.jit
def attn_fwd( def attn_fwd(
...@@ -312,38 +682,53 @@ def attn_fwd( ...@@ -312,38 +682,53 @@ def attn_fwd(
K, K,
V, V,
bias, bias,
sm_scale, SM_SCALE: tl.constexpr,
L, L,
Out, Out,
stride_qz, stride_qz: tl.int64,
stride_qh, stride_qh: tl.int64,
stride_qm, stride_qm: tl.int64,
stride_qk, stride_qk: tl.int64,
stride_kz, stride_kz: tl.int64,
stride_kh, stride_kh: tl.int64,
stride_kn, stride_kn: tl.int64,
stride_kk, stride_kk: tl.int64,
stride_vz, stride_vz: tl.int64,
stride_vh, stride_vh: tl.int64,
stride_vk, stride_vk: tl.int64,
stride_vn, stride_vn: tl.int64,
stride_oz, stride_oz: tl.int64,
stride_oh, stride_oh: tl.int64,
stride_om, stride_om: tl.int64,
stride_on, stride_on: tl.int64,
stride_bz, stride_bz: tl.int64,
stride_bh, stride_bh: tl.int64,
stride_bm, stride_bm: tl.int64,
stride_bn, stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
dropout_p,
philox_seed, philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base, philox_offset_base,
encoded_softmax, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HQ: tl.constexpr,
HK: tl.constexpr, HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, IS_ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr, VARLEN: tl.constexpr,
...@@ -351,24 +736,39 @@ def attn_fwd( ...@@ -351,24 +736,39 @@ def attn_fwd(
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr, SHOULD_PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr, USE_BIAS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr,
USE_P_SCALE: tl.constexpr,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1) if o_descale_ptr is not None:
off_z = tl.program_id(2) o_descale = tl.load(o_descale_ptr)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) start_m: tl.int64 = tl.program_id(0)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN: if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too # We have a one-size-fits-all grid in id(0). Some seqlens might be
# small for all start_m so for those we return early. # too small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q: if start_m * BLOCK_M > seqlen_q:
return continue_condition = False
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
...@@ -378,156 +778,186 @@ def attn_fwd( ...@@ -378,156 +778,186 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking. if continue_condition:
# This is because for seqlen_q > seqlen_k, M rows of the attn scores # Now we compute whether we need to exit early due to causal
# are completely masked, resulting in 0s written to the output, and # masking. This is because for seqlen_q > seqlen_k, M rows of the
# inf written to LSE. We don't need to do any GEMMs in this case. # attn scores are completely masked, resulting in 0s written to the
# This block of code determines what N is, and if this WG is operating # output, and inf written to LSE. We don't need to do any GEMMs in
# on those M rows. # this case. This block of code determines what N is, and if this
# WG is operating on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N) n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL: if (IS_CAUSAL):
# If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means # If seqlen_q != seqlen_k, attn scores are rectangular which
# the causal mask boundary is bottom right aligned, and ends at either # means the causal mask boundary is bottom right aligned, and
# the top edge (seqlen_q < seqlen_k) or left edge. # ends at either the top edge (seqlen_q < seqlen_k) or left
# This captures the decrease in n_blocks if we have a rectangular attn # edge. This captures the decrease in n_blocks if we have a
# matrix # rectangular attn matrix
n_blocks_seqlen = cdiv_fn( n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only # This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks # if IS_CAUSAL. Otherwise we want to always iterate through all
# n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen) n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is # If we have no blocks after adjusting for seqlen deltas, this
# part of the blocks that are all 0. We exit early. # WG is part of the blocks that are all 0. We exit early.
if n_blocks <= 0: if n_blocks <= 0:
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
off_h_q * stride_oh) cu_seqlens_q_start * stride_om)
O_block_ptr = tl.make_block_ptr( o_ptrs = (o_offset + offs_m[:, None] * stride_om +
base=Out + o_offset, offs_d[None, :] * stride_on)
shape=(seqlen_q, BLOCK_DMODEL), acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
strides=(stride_om, stride_on), o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to(
offsets=(start_m * BLOCK_M, 0), [BLOCK_M, BLOCK_DMODEL])
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result # We still need to write 0s to the result
# tl.store(O_block_ptr, tl.store(o_ptrs, acc, mask=o_ptrs_mask)
# acc.to(Out.type.element_ty), boundary_check=(0,1)) # The tensor allocated for L is based on MAX_SEQLENS_Q as
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q # that is statically known.
# + offs_m l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
off_h_q * MAX_SEQLENS_Q + offs_m)
# We store inf to LSE, not -inf because in the bwd pass, # We store inf to LSE, not -inf because in the bwd pass,
# we subtract this # we subtract this from qk which makes it -inf, such that
# from qk which makes it -inf, such that exp(qk - inf) = 0 # exp(qk - inf) = 0 for these masked blocks.
# for these masked blocks. l_value = tl.full([BLOCK_M],
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) value=float("inf"),
# tl.store(l_ptrs, l) dtype=tl.float32)
# TODO: Should dropout and return encoded softmax be handled here? l_ptrs_mask = offs_m < MAX_SEQLENS_Q
return tl.store(l_ptrs, l_value, mask=l_ptrs_mask)
# TODO: Should dropout and return encoded softmax be
# handled here too?
continue_condition = False
# return
if continue_condition:
# If MQA / GQA, set the K and V head offsets appropriately. # If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0 n_extra_tokens = 0
if seqlen_k < BLOCK_N: if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N: elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N n_extra_tokens = seqlen_k % BLOCK_N
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL
!= BLOCK_DMODEL)
# Compute pointers for all the tensors used in this kernel. # Compute pointers for all the tensors used in this kernel.
q_offset = (off_z * stride_qz + off_h_q * stride_qh + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm) cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr( q_ptrs = (q_offset + offs_m[:, None] * stride_qm +
base=Q + q_offset, offs_d[None, :] * stride_qk)
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), k_offset = (K + off_z * stride_kz + off_h_k * stride_kh +
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn) cu_seqlens_k_start * stride_kn)
K_block_ptr = tl.make_block_ptr( k_ptrs = (k_offset + offs_d[:, None] * stride_kk +
base=K + k_offset, offs_n[None, :] * stride_kn)
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), v_offset = (V + off_z * stride_vz + off_h_k * stride_vh +
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk) cu_seqlens_k_start * stride_vk)
V_block_ptr = tl.make_block_ptr( v_ptrs = (v_offset + offs_n[:, None] * stride_vk +
base=V + v_offset, offs_d[None, :] * stride_vn)
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), # Compute pointers for all scale tensors used in this kernel.
strides=(stride_vk, stride_vn),
offsets=(0, 0), IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & (
block_shape=(BLOCK_N, BLOCK_DMODEL), not IS_EIGHT_BIT_KV)
order=(1, 0), if IS_EIGHT_BIT:
) if k_descale_has_singleton:
if BIAS_TYPE != 0: k_descale_ptrs = k_descale_ptr
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else: else:
bias_ptr = None k_descale_ptrs = k_descale_ptr + off_h_k
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \ if v_descale_has_singleton:
+ (off_z * HQ + off_h_q) \ v_descale_ptrs = v_descale_ptr
* seqlen_q * seqlen_k
else: else:
v_descale_ptrs = v_descale_ptr + off_h_k
if not IS_EIGHT_BIT_KV:
if q_descale_has_singleton:
q_descale_ptrs = q_descale_ptr
else:
q_descale_ptrs = q_descale_ptr + off_h_q
if USE_P_SCALE:
if p_descale_has_singleton:
p_scale_ptrs = p_scale_ptr
p_descale_ptrs = p_descale_ptr
else:
p_scale_ptrs = p_scale_ptr + off_h_q
p_descale_ptrs = p_descale_ptr + off_h_q
if USE_BIAS:
bias_offset = off_h_q * stride_bh
bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
offs_n[None, :] * stride_bn)
else:
bias_ptrs = None
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
batch_philox_offset = 0 batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout. # We can ask to return the dropout mask without doing any
# In this case, we return an invalid pointer so indicate the mask is not i # dropout. In this case, we return an invalid pointer so
# valid. # indicate the mask is not valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset. if SHOULD_RETURN_ENCODED_SOFTMAX:
if RETURN_ENCODED_SOFTMAX: encoded_sm_base = (encoded_softmax +
encoded_softmax_block_ptr = tl.make_block_ptr( off_h_q * seqlen_q * seqlen_k)
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, encoded_sm_ptrs = (encoded_sm_base +
shape=(seqlen_q, seqlen_k), offs_m[:, None] * seqlen_k +
strides=(seqlen_k, 1), offs_n[None, :])
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else: else:
encoded_softmax_block_ptr = 0 encoded_sm_ptrs = None
# initialize pointer to m and l # initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not # scale sm_scale by log_2(e) and use 2^x in the loop as we do
# have native e^x support in HW. # not have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089 QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks. # Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, padded_head, "zero") q_ptrs_mask = offs_m[:, None] < seqlen_q
q = (q * qk_scale).to(Q_block_ptr.type.element_ty) if USE_PADDED_HEAD:
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
if IS_EIGHT_BIT:
k_descale = tl.load(k_descale_ptrs)
v_descale = tl.load(v_descale_ptrs)
q_descale = None if IS_EIGHT_BIT_KV else tl.load(
q_descale_ptrs)
if USE_P_SCALE:
p_scale = tl.load(p_scale_ptrs)
p_descale = tl.load(p_descale_ptrs)
else:
p_scale = None
p_descale = None
else:
q_descale = None
k_descale = None
v_descale = None
p_scale = None
p_descale = None
# Here we compute how many full and masked blocks we have. # Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0 padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL: if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks. # There are always at least BLOCK_M // BLOCK_N masked
# Additionally there might be one more due to dissimilar seqlens. # blocks. Additionally there might be one more due to
# dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else: else:
# Padding on Q does not need to be masked in the FA loop. # Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional # if IS_CAUSAL, not is_modulo_mn does not always result in an
# block. In this case we might exceed n_blocks so pick the min. # additional block. In this case we might exceed n_blocks so
# pick the min.
masked_blocks = min(masked_blocks, n_blocks) masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks n_full_blocks = n_blocks - masked_blocks
block_min = 0 block_min = 0
block_max = n_blocks * BLOCK_N block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its # Compute for full blocks. Here we set causal to false
# value because there is no masking. Similarly we do not need padding. # regardless of its actual value because there is no masking.
# Similarly we do not need padding.
if n_full_blocks > 0: if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i = _attn_fwd_inner(
...@@ -535,21 +965,29 @@ def attn_fwd( ...@@ -535,21 +965,29 @@ def attn_fwd(
l_i, l_i,
m_i, m_i,
q, q,
K_block_ptr, k_ptrs,
V_block_ptr, v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
seqlen_k, seqlen_k,
dropout_p, seqlen_q,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_softmax_block_ptr, encoded_sm_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_min,
block_max, block_max,
0, 0,
0, 0,
0, 0,
bias_ptr, alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
# IS_CAUSAL, .... # IS_CAUSAL, ....
False, False,
BLOCK_M, BLOCK_M,
...@@ -557,265 +995,381 @@ def attn_fwd( ...@@ -557,265 +995,381 @@ def attn_fwd(
BLOCK_N, BLOCK_N,
offs_m, offs_m,
offs_n, offs_n,
# _, MASK_STEPS, ... # _, SHOULD_MASK_STEPS, ...
PRE_LOAD_V, SHOULD_PRE_LOAD_V,
False, False,
ENABLE_DROPOUT, SHOULD_RETURN_ENCODED_SOFTMAX,
RETURN_ENCODED_SOFTMAX, USE_PADDED_HEAD,
padded_head, IS_ACTUAL_BLOCK_DMODEL,
) QK_SCALE,
IS_EIGHT_BIT_GEMM,
USE_P_SCALE,
IS_EIGHT_BIT_KV,
QUANT_DTYPE)
block_min = block_max block_min = block_max
block_max = n_blocks * BLOCK_N block_max = n_blocks * BLOCK_N
tl.debug_barrier() tl.debug_barrier()
# Remaining blocks, if any, are full / not masked. # Remaining blocks, if any, are full / not masked.
if masked_blocks > 0: if (masked_blocks > 0):
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 if IS_CAUSAL:
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) offs_n_causal = offs_n + (seqlen_q - seqlen_k)
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) else:
if bias_ptr is not None: offs_n_causal = 0
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) k_ptrs += n_full_blocks * BLOCK_N * stride_kn
if RETURN_ENCODED_SOFTMAX: v_ptrs += n_full_blocks * BLOCK_N * stride_vk
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, if USE_BIAS:
(0, n_full_blocks)) bias_ptrs += n_full_blocks * BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i = _attn_fwd_inner(
acc, acc,
l_i, l_i,
m_i, m_i,
q, q,
K_block_ptr, k_ptrs,
V_block_ptr, v_ptrs,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
seqlen_k, seqlen_k,
dropout_p, seqlen_q,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_softmax_block_ptr, encoded_sm_ptrs,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
bias_ptr, alibi_slope,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL, IS_CAUSAL,
BLOCK_M, BLOCK_M,
BLOCK_DMODEL, BLOCK_DMODEL,
BLOCK_N, BLOCK_N,
offs_m, offs_m,
offs_n, offs_n,
# _, MASK_STEPS, ... # _, SHOULD_MASK_STEPS, ...
PRE_LOAD_V, SHOULD_PRE_LOAD_V,
True, True,
ENABLE_DROPOUT, SHOULD_RETURN_ENCODED_SOFTMAX,
RETURN_ENCODED_SOFTMAX, USE_PADDED_HEAD,
padded_head, IS_ACTUAL_BLOCK_DMODEL,
) QK_SCALE,
IS_EIGHT_BIT_GEMM,
USE_P_SCALE,
IS_EIGHT_BIT_KV,
QUANT_DTYPE)
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
if USE_P_SCALE:
acc *= p_descale
acc *= v_descale
# epilogue # epilogue
acc = acc / l_i[:, None] # This helps the compiler do Newton Raphson on l_i vs on acc
if ENABLE_DROPOUT: # which is much larger.
acc = acc / (1 - dropout_p) l_recip = 1 / l_i[:, None]
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, acc = acc * l_recip
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here # If seqlen_q > seqlen_k but the delta is not a multiple of
# and store 0s where there are NaNs as these rows should've been zeroed out. # BLOCK_M, then we have one block with a row of all NaNs which
# come from computing softmax over a row of all
# -infs (-inf - inf = NaN). We check for that here and store 0s
# where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k causal_start_idx = seqlen_q - seqlen_k
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102
if o_descale_ptr is not None:
acc = quant_fp8(acc, o_descale)
acc = acc.to(Out.type.element_ty) acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102 if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: if (causal_start_idx > start_m_idx
and causal_start_idx < end_m_idx):
out_mask_boundary = tl.full((BLOCK_DMODEL, ), out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx, causal_start_idx,
dtype=tl.int32) dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :]) >= out_mask_boundary[None, :])
z = 0.0 z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) acc = tl.where(out_ptrs_mask, acc,
z.to(acc.type.element_ty))
# write back LSE # write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q +
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last off_h_q * MAX_SEQLENS_Q + offs_m)
# few rows. This is only true for the last M block. For others, # If seqlen_q not multiple of BLOCK_M, we need to mask out the
# overflow_size will be -ve # last few rows. This is only true for the last M block.
# overflow_size = end_m_idx - seqlen_q # For others, overflow_size will be -ve
# if overflow_size > 0: overflow_size = end_m_idx - seqlen_q
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) if overflow_size > 0:
# # This is a > check because mask being 0 blocks the store. boundary = tl.full((BLOCK_M, ),
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) BLOCK_M - overflow_size,
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) dtype=tl.int32)
# else: l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
# tl.store(l_ptrs, m_i + tl.math.log2(l_i)) tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O # write back O
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
off_h_q * stride_oh) cu_seqlens_q_start * stride_om)
O_block_ptr = tl.make_block_ptr( o_ptrs = (o_offset + offs_m[:, None] * stride_om +
base=Out + o_offset, offs_d[None, :] * stride_on)
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
strides=(stride_om, stride_on), if overflow_size > 0:
offsets=(start_m * BLOCK_M, 0), o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
block_shape=(BLOCK_M, BLOCK_DMODEL), if USE_PADDED_HEAD:
order=(1, 0), o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
) < IS_ACTUAL_BLOCK_DMODEL)
# Need boundary check on this to make sure the padding from the tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1)) def get_shape_from_layout(q, k, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
def check_args( if metadata.layout == 'thd':
q, nheads_q, nheads_k = q.shape[1], k.shape[1]
k, head_size = q.shape[-1]
v, batch = metadata.num_contexts
o, elif metadata.layout == 'bhsd':
varlen=True, batch, nheads_q, _, head_size = q.shape
max_seqlens=None, nheads_k = k.shape[1]
cu_seqlens_q=None, elif metadata.layout == 'bshd':
cu_seqlens_k=None, batch, _, nheads_q, head_size = q.shape
): nheads_k = k.shape[2]
assert q.dim() == k.dim() and q.dim() == v.dim() return batch, nheads_q, nheads_k, head_size
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape def get_strides_from_layout(q, k, v, o, metadata):
total_k, nheads_k, _ = k.shape assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None STRIDE_PERMUTATIONS = {
assert len(cu_seqlens_q) == len(cu_seqlens_k) 'thd': (None, 1, 0, 2),
else: 'bhsd': (0, 1, 2, 3),
assert q.dim() == 4 'bshd': (0, 2, 1, 3),
batch, nheads_q, seqlen_q, head_size = q.shape }
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0 perm = STRIDE_PERMUTATIONS[metadata.layout]
assert k.shape == v.shape stride = lambda x, p: (0 if p is None else x.stride(p))
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] strides = lambda x: (stride(x, p) for p in perm)
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype return tuple(strides(x) for x in [q, k, v, o])
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(ctx, q, k, v, o, metadata: MetaData):
ctx, # NOTE: a large bias tensor leads to overflow during pointer arithmetic
q, if (metadata.bias is not None):
k, assert (metadata.bias.numel() < 2**31)
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args( if o is None:
if metadata.eight_bit:
o = torch.empty_like(
q, q,
k, dtype=metadata.output_dtype if metadata.output_dtype
v, is not None else metadata.eight_bit_dtype_torch)
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else: else:
batch, seqlen_q, nheads_q, head_size = q.shape o = torch.empty_like(q, dtype=q.dtype)
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. metadata.check_args(q, k, v, o)
unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims: batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
padded_d_model = None q, k, metadata)
for i in unpadded_head_dims: q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
if i > head_size: q, k, v, o, metadata)
padded_d_model = i
break
assert padded_d_model is not None
else:
padded_d_model = head_size
grid = lambda META: ( # Get closest power of 2 over or equal to 32.
triton.cdiv(max_seqlens_q, META["BLOCK_M"]), padded_d_model = 1 << (head_size - 1).bit_length()
nheads_q, # Smallest head_dim supported is 16. If smaller, the tile in the
batch, # kernel is padded - there is no padding in memory for any dims.
) padded_d_model = max(padded_d_model, 16)
# encoded_softmax is used to validate dropout behavior vs the
# PyTorch SDPA math backend reference. We zero this out to give a
# consistent starting point and then populate it with the output of
# softmax with the sign bit set according to the dropout mask.
# The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else:
encoded_softmax = None encoded_softmax = None
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q),
device=q.device,
dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing. # Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52 philox_seed = 0x1BF52
philox_offset = 0x1D4B42 philox_offset = 0x1D4B42
if bias is not None: if metadata.bias is not None:
bias_strides = ( bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
bias.stride(0), metadata.bias.stride(2), metadata.bias.stride(3))
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else: else:
bias_strides = (0, 0, 0, 0) bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None:
alibi_strides = (metadata.alibi_slopes.stride(0),
metadata.alibi_slopes.stride(1))
else:
alibi_strides = (0, 0)
if metadata.eight_bit:
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = (
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid]( attn_fwd[grid](
q, q,
k, k,
v, v,
bias, metadata.bias,
sm_scale, metadata.sm_scale,
None, M,
o, o,
*q_strides, *q_strides,
*k_strides, *k_strides,
*v_strides, *v_strides,
*o_strides, *o_strides,
*bias_strides, *bias_strides,
cu_seqlens_q, *alibi_strides,
cu_seqlens_k, q_descale,
dropout_p=0.0, k_descale,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed, philox_seed=philox_seed,
philox_offset_base=philox_offset, philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HQ=nheads_q,
HK=nheads_k, HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size, IS_ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=causal, IS_CAUSAL=metadata.causal,
VARLEN=True, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1, USE_BIAS=metadata.bias is not None,
ENABLE_DROPOUT=False, USE_ALIBI=metadata.alibi_slopes is not None,
RETURN_ENCODED_SOFTMAX=False, SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax,
) IS_EIGHT_BIT=metadata.eight_bit,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv,
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid ctx.grid = grid
ctx.sm_scale = sm_scale ctx.sm_scale = metadata.sm_scale
ctx.BLOCK_DMODEL = head_size ctx.BLOCK_DMODEL = head_size
ctx.causal = causal ctx.causal = metadata.causal
ctx.dropout_p = 0.0 ctx.alibi_slopes = metadata.alibi_slopes
ctx.philox_seed = philox_seed ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False ctx.return_encoded_softmax = metadata.return_encoded_softmax
return o, encoded_softmax return o, encoded_softmax
triton_attention = _attention.apply triton_attention_rocm = _attention.apply
def scale_fp8(t, scale=None):
    """Quantize ``t`` through ``ops.scaled_fp8_quant`` while keeping its shape.

    ``t`` is first collapsed to two dimensions (every leading dimension is
    merged, the last dimension is kept) because that is the form handed to
    ``ops.scaled_fp8_quant``. The quantized result is then reshaped back to
    the shape of ``t``.

    Args:
        t: Tensor to quantize. It must have at least one dimension, since
            its last dimension is read.
        scale: Value forwarded unchanged as the second argument of
            ``ops.scaled_fp8_quant``. Defaults to ``None``.
            NOTE(review): presumably ``None`` asks the op to derive the
            scale itself - confirm against ``ops.scaled_fp8_quant``.

    Returns:
        A ``(quantized, scale_out)`` pair, where ``quantized`` has the same
        shape as ``t`` and ``scale_out`` is the second value returned by
        ``ops.scaled_fp8_quant``.
    """
    original_shape = t.shape
    flattened = t.reshape(-1, original_shape[-1])
    quantized, scale_out = ops.scaled_fp8_quant(flattened, scale)
    return quantized.reshape(original_shape), scale_out
def maybe_quantize_fp8(t, scale):
    """Return ``t`` in the platform fp8 dtype, quantizing only when needed.

    Args:
        t: Tensor to check. If its dtype already equals
            ``current_platform.fp8_dtype()`` it is returned untouched.
        scale: Scale passed on to ``scale_fp8`` when quantization happens.

    Returns:
        Either ``t`` itself (already in the platform fp8 dtype) or the
        quantized tensor produced by ``scale_fp8``. The scale returned by
        ``scale_fp8`` is discarded.
    """
    if t.dtype == current_platform.fp8_dtype():
        # Already stored in the target dtype; nothing to do.
        return t
    quantized, _ = scale_fp8(t, scale)
    return quantized
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
    """Run ``maybe_quantize_fp8`` on q, k and v with their matching scales.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor.
        fp8_scales: Iterable of exactly four entries, in the order
            ``(q_scale, k_scale, v_scale, p_scale)``. Only the first three
            are used by this function.

    Returns:
        The tuple ``(q, k, v)`` after each tensor has been passed through
        ``maybe_quantize_fp8``.

    Raises:
        ValueError: If ``fp8_scales`` does not unpack into four values.
    """
    # All four entries are unpacked on purpose: the fourth (p_scale) is not
    # needed here, but a wrongly sized ``fp8_scales`` must still fail.
    q_scale, k_scale, v_scale, _ = fp8_scales
    quantized = [
        maybe_quantize_fp8(tensor, tensor_scale)
        for tensor, tensor_scale in ((q, q_scale), (k, k_scale), (v, v_scale))
    ]
    return tuple(quantized)
# query - [num_tokens, num_heads, head_size]
# key - [num_tokens, num_kv_heads, head_size]
# value - [num_tokens, num_kv_heads, head_size]
# output - [num_tokens, num_heads, head_size]
def triton_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlens_q: int,
    max_seqlens_k: int,
    causal: bool = False,
    sm_scale: float = 1.0,
    bias: Optional[torch.Tensor] = None,
    fp8_scales: Optional[tuple[float, ...]] = None,
    input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Bundle the arguments into a ``MetaData`` and call the Triton kernel.

    The flat argument list is packed into a ``MetaData`` object, q/k/v are
    optionally quantized to fp8, and the call is delegated to
    ``triton_attention_rocm``.

    Args:
        q: Query tensor, laid out as ``[num_tokens, num_heads, head_size]``.
        k: Key tensor, laid out as ``[num_tokens, num_kv_heads, head_size]``.
        v: Value tensor, laid out as
            ``[num_tokens, num_kv_heads, head_size]``.
        o: Output tensor, laid out as
            ``[num_tokens, num_heads, head_size]``.
        cu_seqlens_q: Stored in the metadata as ``cu_seqlens_q``.
        cu_seqlens_k: Stored in the metadata as ``cu_seqlens_k``.
        max_seqlens_q: Stored in the metadata as ``max_seqlens_q``.
        max_seqlens_k: Stored in the metadata as ``max_seqlens_k``.
        causal: Stored in the metadata as ``causal``.
        sm_scale: Stored in the metadata as ``sm_scale``.
        bias: Optional bias tensor, stored in the metadata as ``bias``.
        fp8_scales: Optional four-entry tuple unpacked as
            ``(q_descale, k_descale, v_descale, p_scale)``. When given,
            q, k and v are also passed through
            ``check_and_maybe_quantize_qkv`` before the kernel call.
        input_scale: Stored in the metadata as ``o_scale``.

    Returns:
        Whatever ``triton_attention_rocm`` returns.
        NOTE(review): ``_attention.forward`` in this file appears to return
        ``(o, encoded_softmax)``, so the ``torch.Tensor`` annotation may be
        narrower than the actual value - confirm.

    Raises:
        ValueError: If ``fp8_scales`` is given but does not unpack into
            four values.
    """
    use_fp8 = fp8_scales is not None

    # Without fp8 scales every descale / scale entry of the metadata is None.
    q_descale, k_descale, v_descale, p_scale = (fp8_scales if use_fp8 else
                                                (None, ) * 4)

    metadata_kwargs = {
        "sm_scale": sm_scale,
        "max_seqlens_q": max_seqlens_q,
        "max_seqlens_k": max_seqlens_k,
        "causal": causal,
        "bias": bias,
        # The output dtype is taken from q before any fp8 quantization of q.
        "output_dtype": q.dtype,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "p_scale": p_scale,
        "o_scale": input_scale,
    }
    attn_metadata = MetaData(**metadata_kwargs)

    if use_fp8:
        q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)

    return triton_attention_rocm(q, k, v, o, attn_metadata)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment