Unverified Commit 1b7cfd5a authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm][V0][Attention] Revert to the previous FA triton kernel (#18226)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent da4b69d0
...@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
and layer._v_scale and layer._prob_scale and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8") and self.kv_cache_dtype == "fp8")
full_scales = ( full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale, layer._q_scale.item(), layer._k_scale.item(),
layer._prob_scale) if use_fp8_scales else None layer._v_scale.item(),
layer._prob_scale.item()) if use_fp8_scales else None
self.triton_attn_func( self.triton_attn_func(
query, query,
key, key,
......
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
This is a Triton implementation of the Flash Attention v2 algorithm This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
See https://tridao.me/publications/flash2/flash2.pdf (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Credits: Features supported:
AMD Triton kernels team
OpenAI kernel team
Currently only the forward kernel is supported, and contains these features:
1) Fwd with causal masking 1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths 2) Any sequence lengths without padding (currently fwd kernel only)
3) Arbitrary head sizes 3) Support for different sequence lengths for q and k
4) Multi and grouped query attention 4) Nested tensor API currently does not support dropout or bias.
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
""" Not currently supported:
from typing import Optional 1) Non power of two head dims
"""
import torch import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx1x
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] torch_dtype: tl.constexpr = torch.float16
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)
FP8_MIN = triton.language.constexpr(default_float8_info.min)
# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
FP8_MAX = triton.language.constexpr(
default_float8_info.max if default_eight_bit_dtype_torch !=
torch.float8_e4m3fnuz else 224.0)
class MetaData:
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
eight_bit = False
layout = None
return_encoded_softmax = False
eight_bit_dtype_triton = default_eight_bit_dtype_triton
eight_bit_dtype_torch = default_eight_bit_dtype_torch
output_dtype = None
# Note about layouts:
#
# thd - [num_tokens, num_heads, head_size]
# bshd - [batch_size, seq_len, num_heads, head_size]
# bhsd - [batch_size, num_heads, seq_len, head_size]
#
# This is for each tensor, all tensors must have same layout.
# Q can have num_heads and seq_len differ from from K and V,
# however K and V must agree on this.
#
# Notes about varlen and bias:
# Only one or the other is implemented, meaning can't combine
# both varlen and bias right now.
#
# Note about quantization:
# Only 8-bit quantization supported (for now) and specifically fp8.
# Scales must be tensors.
# o_scale: This is 'output scaling', but comes from parameter called
# 'input_scale', this is applied to the output from the kernel.
# o_scale should be None if none of the other quantization parameters
# are used.
#
# NOTE: Object is in a tentatively good state after initialized, however,
# to verify, call check_args(q,k,v,o) where o is the output tensor.
def __init__(
self,
sm_scale=1.0,
layout=None, # layout can be 'bshd', 'bhsd', or 'thd'
output_dtype=None,
max_seqlens_q=0,
max_seqlens_k=0,
# varlen params
cu_seqlens_q=None, # only 'thd' layout supported for varlen
cu_seqlens_k=None,
# quant params
q_descale=None,
k_descale=None,
v_descale=None,
p_scale=None,
o_scale=None,
# bias params
bias=None, # varlen not implemented for bias
seqlen_q=None,
seqlen_k=None,
# alibi params
alibi_slopes=None,
alibi_batch=None,
alibi_nheads=None,
# causal
causal=None,
):
self.sm_scale = sm_scale
self.output_dtype = output_dtype
self.max_seqlens_q = max_seqlens_q
self.max_seqlens_k = max_seqlens_k
self.layout = layout
if cu_seqlens_q is not None or cu_seqlens_k is not None:
assert cu_seqlens_q is not None and cu_seqlens_k is not None
assert layout is None or layout not in [
'bshd', 'bhsd'
], "Varlen only implemented for thd layout"
self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
if any(x is not None for x in quant_params):
p_descale = 1.0 / p_scale if p_scale is not None else None
self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale)
if bias is not None:
self.need_bias(bias, seqlen_q, seqlen_k)
if alibi_slopes is not None:
self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
if causal is not None and causal:
self.need_causal()
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(
cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(),
self.max_seqlens_q)
self.max_seqlens_k = max(
cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(),
self.max_seqlens_k)
def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale):
self.eight_bit = True
self.q_descale = q_descale
self.k_descale = k_descale
self.v_descale = v_descale
self.p_scale = p_scale
self.p_descale = p_descale
self.o_scale = o_scale
self.use_p_scale = (p_scale is not None) and (
p_descale is not None) and (v_descale is not None)
self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
and (v_descale is not None))
self.eight_bit_dtype_torch = default_eight_bit_dtype_torch
def need_bias(self, bias, seqlen_q, seqlen_k):
assert bias is not None
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
q, k, self)
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
assert not self.return_encoded_softmax
else:
assert q.dim() == 4
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
if self.eight_bit:
if self.eight_bit_kv:
assert (v.dtype == k.dtype
and k.dtype == self.eight_bit_dtype_torch)
assert q.dtype != k.dtype
assert (self.v_descale is not None) and (self.k_descale
is not None)
else:
assert (q.dtype == k.dtype and q.dtype == v.dtype
and q.dtype == self.eight_bit_dtype_torch)
assert (self.q_descale
is not None) and (self.k_descale
is not None) and (self.v_descale
is not None)
if self.use_p_scale:
assert (self.p_scale is not None) and (self.p_descale
is not None)
else:
assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
@triton.jit @triton.jit
...@@ -243,85 +40,40 @@ def max_fn(x, y): ...@@ -243,85 +40,40 @@ def max_fn(x, y):
return tl.math.max(x, y) return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit @triton.jit
def masked_load(ptrs, offset_first, offset_second, boundary_first, def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
boundary_second): ms = tl.arange(0, m)
if offset_first is not None and offset_second is not None: ns = tl.arange(0, n)
mask = (offset_first[:, None] < boundary_first) & \ return philox_offset + ms[:, None] * stride + ns[None, :]
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit @triton.jit
def compute_alibi_block(alibi_slope, def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
seqlen_q, rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
seqlen_k, stride).to(tl.uint32)
offs_m, # TODO: use tl.randint for better performance
offs_n, return tl.rand(philox_seed, rng_offsets)
transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to
# the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is
# masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that
# spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
# offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] =
# [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q -
offs_n[None, :])
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): @triton.jit
q_idx = torch.arange(seqlen_q, dtype=torch.int32, def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
k_idx = torch.arange(seqlen_k, dtype=torch.int32, stride)
device="cuda").unsqueeze(0) # (1, N_CTX_K) rng_keep = rng_output > dropout_p
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - return rng_keep
k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit @triton.jit
def quant_fp8(x, scale): def load_fn(block_ptr, first, second, pad):
x *= scale if first and second:
x = tl.clamp(x, FP8_MIN, FP8_MAX) tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
return x elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit @triton.jit
...@@ -330,68 +82,61 @@ def _attn_fwd_inner( ...@@ -330,68 +82,61 @@ def _attn_fwd_inner(
l_i, l_i,
m_i, m_i,
q, q,
k_ptrs, K_block_ptr,
v_ptrs, V_block_ptr,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
actual_seqlen_k, actual_seqlen_k,
actual_seqlen_q, dropout_p,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_sm_ptrs, encoded_softmax_block_ptr,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
alibi_slope, bias_ptr,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr, OFFS_N: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr, PRE_LOAD_V: tl.constexpr,
SHOULD_MASK_STEPS: tl.constexpr, MASK_STEPS: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
USE_PADDED_HEAD: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, PADDED_HEAD: tl.constexpr,
QK_SCALE: tl.constexpr, USE_FP8: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr, qk_scale,
USE_P_SCALE: tl.constexpr, p_descale,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
# loop over k, v, and update accumulator # loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N): for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if # For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range. # we load all BLOCK_N. For others, the blocks are all within range.
k_offs_n = start_n + tl.arange(0, k = load_fn(
BLOCK_N) if SHOULD_MASK_STEPS else None K_block_ptr,
k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) PADDED_HEAD,
k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, MASK_STEPS and (n_extra_tokens != 0),
actual_seqlen_k) "zero",
if SHOULD_PRE_LOAD_V: )
# We can use the same offsets as k, just with dims transposed. if PRE_LOAD_V:
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, v = load_fn(
IS_ACTUAL_BLOCK_DMODEL) V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need # We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n # to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block. # TODO: This can be optimized to only be true for the padded block.
if SHOULD_MASK_STEPS: # noqa: SIM102 if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to # If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size # mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# is_modulo_mn. last step might get wasted but that is okay. # if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case. # check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], boundary_m = tl.full([BLOCK_M],
...@@ -404,107 +149,112 @@ def _attn_fwd_inner( ...@@ -404,107 +149,112 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf")) qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ---- # -- compute qk ----
if IS_EIGHT_BIT_GEMM: qk += tl.dot(q, k)
qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * if USE_FP8:
QK_SCALE) qk *= qk_scale
else: if bias_ptr is not None:
if IS_EIGHT_BIT_KV: bias = load_fn(bias_ptr, False, MASK_STEPS
k = (k * k_descale).to(q.type.element_ty) and (n_extra_tokens != 0), "zero")
qk += (tl.dot(q, k) * QK_SCALE) # While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
if bias_ptrs is not None: # scale factor of log2(e) which we must also multiply the bias with.
bias_offs_n = start_n + tl.arange( qk += bias * 1.44269504089
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None] qk = qk - m_ij[:, None]
p = tl.math.exp2(qk) p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout # CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
if SHOULD_RETURN_ENCODED_SOFTMAX: if ENABLE_DROPOUT:
tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator -- # -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij) alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
if not SHOULD_PRE_LOAD_V: if not PRE_LOAD_V:
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, v = load_fn(
IS_ACTUAL_BLOCK_DMODEL) V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i # -- update m_i and l_i
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
# update m_i and l_i # update m_i and l_i
m_i = m_ij m_i = m_ij
if IS_EIGHT_BIT_GEMM: if USE_FP8:
if USE_P_SCALE: p *= p_descale
p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
acc += tl.dot(p, v) acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
else:
# v is in eight_bit but p is not, we want the gemm in p's type V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
acc += tl.dot(p, v.to(p.type.element_ty)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
else: if bias_ptr is not None:
if IS_EIGHT_BIT_KV: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
v = (v * v_descale).to(p.type.element_ty) if RETURN_ENCODED_SOFTMAX:
acc += tl.dot(p.to(v.type.element_ty), v) encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i return acc, l_i, m_i
def get_cdna_autotune_configs(): def get_cdna_autotune_configs():
return [ return [
triton.Config(
{
'BLOCK_M': 256,
'BLOCK_N': 64,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 128, 'BLOCK_N': 128,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 256,
'BLOCK_N': 64, 'BLOCK_N': 128,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 64, 'BLOCK_N': 64,
'waves_per_eu': 3, 'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
...@@ -512,168 +262,141 @@ def get_cdna_autotune_configs(): ...@@ -512,168 +262,141 @@ def get_cdna_autotune_configs():
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 64, 'BLOCK_N': 64,
'waves_per_eu': 1, 'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': True
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 32, 'BLOCK_N': 64,
'waves_per_eu': 2, 'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 64,
'BLOCK_N': 32, 'BLOCK_N': 64,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 32, 'BLOCK_N': 32,
'waves_per_eu': 2, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
# triton.Config(
# {
# "BLOCK_M": 16,
# "BLOCK_N": 16,
# "waves_per_eu": 1,
# "PRE_LOAD_V": False,
# },
# num_stages=1,
# num_warps=4,
# ),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 32,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 32,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 16, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 16,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 16, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 16,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=2),
# Fall-back config.
triton.Config(
{
'BLOCK_M': 16,
'BLOCK_N': 16,
'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
], [ # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', # triton.Config(
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' # {
] # 'BLOCK_M': 16,
# 'BLOCK_N': 16,
# 'waves_per_eu': 4,
def get_general_autotune_configs(): # 'PRE_LOAD_V': False
return [ # },
triton.Config( # num_stages=1,
{ # num_warps=2),
'BLOCK_M': 128, # triton.Config(
'BLOCK_N': 128, # {
'SHOULD_PRE_LOAD_V': False, # 'BLOCK_M': 16,
'GRID_CU_MULTIP': 2 # 'BLOCK_N': 16,
}, # 'waves_per_eu': 2,
num_stages=1, # 'PRE_LOAD_V': False
num_warps=4), # },
triton.Config( # num_stages=1,
{ # num_warps=2),
'BLOCK_M': 128, # # Fall-back config.
'BLOCK_N': 64, # triton.Config(
'SHOULD_PRE_LOAD_V': False, # {
'GRID_CU_MULTIP': 2 # 'BLOCK_M': 16,
}, # 'BLOCK_N': 16,
num_stages=1, # 'waves_per_eu': 1,
num_warps=4), # 'PRE_LOAD_V': False
triton.Config( # },
{ # num_stages=1,
'BLOCK_M': 128, # num_warps=2),
'BLOCK_N': 32, ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def has_cdna_target():
ROCM_CDNA_TARGETS = ["gfx942", "gfx90a", "gfx908"]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_CDNA_TARGETS
def is_rocm_cdna():
return current_platform.is_rocm() and has_cdna_target()
def get_autotune_configs(): def get_autotune_configs():
if is_rocm_cdna(): if on_gfx1x():
return get_cdna_autotune_configs()
elif current_platform.is_rocm():
return get_rdna_autotune_configs() return get_rdna_autotune_configs()
else: else:
return get_general_autotune_configs() return get_cdna_autotune_configs()
autotune_configs, autotune_keys = get_autotune_configs() autotune_configs, autotune_keys = get_autotune_configs()
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.autotune( @triton.autotune(
configs=autotune_configs, configs=autotune_configs,
key=autotune_keys, key=autotune_keys,
use_cuda_graph=True,
) )
@triton.jit @triton.jit
def attn_fwd( def attn_fwd(
...@@ -681,7 +404,13 @@ def attn_fwd( ...@@ -681,7 +404,13 @@ def attn_fwd(
K, K,
V, V,
bias, bias,
SM_SCALE: tl.constexpr, sm_scale,
q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
L, L,
Out, Out,
stride_qz: tl.int64, stride_qz: tl.int64,
...@@ -704,70 +433,44 @@ def attn_fwd( ...@@ -704,70 +433,44 @@ def attn_fwd(
stride_bh: tl.int64, stride_bh: tl.int64,
stride_bm: tl.int64, stride_bm: tl.int64,
stride_bn: tl.int64, stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
dropout_p,
philox_seed, philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base, philox_offset_base,
encoded_softmax, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HQ: tl.constexpr,
HK: tl.constexpr, HK: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr, VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
USE_FP8: tl.constexpr,
USE_FP8_OUT: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr, PRE_LOAD_V: tl.constexpr,
USE_BIAS: tl.constexpr, BIAS_TYPE: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
USE_ALIBI: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr, FP8_MIN: tl.constexpr = float8_info.min,
USE_P_SCALE: tl.constexpr, FP8_MAX: tl.constexpr = float8_info.max,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
start_m = tl.program_id(0)
if o_descale_ptr is not None: off_h_q = tl.program_id(1)
o_descale = tl.load(o_descale_ptr) off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
start_m: tl.int64 = tl.program_id(0) offs_n = tl.arange(0, BLOCK_N)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN: if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be # We have a one-size-fits-all grid in id(0). Some seqlens might be too
# too small for all start_m so for those we return early. # small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q: if start_m * BLOCK_M > seqlen_q:
continue_condition = False return
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
...@@ -777,598 +480,499 @@ def attn_fwd( ...@@ -777,598 +480,499 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K seqlen_k = MAX_SEQLENS_K
if continue_condition: # Now we compute whether we need to exit early due to causal masking.
# Now we compute whether we need to exit early due to causal # This is because for seqlen_q > seqlen_k, M rows of the attn scores
# masking. This is because for seqlen_q > seqlen_k, M rows of the # are completely masked, resulting in 0s written to the output, and
# attn scores are completely masked, resulting in 0s written to the # inf written to LSE. We don't need to do any GEMMs in this case.
# output, and inf written to LSE. We don't need to do any GEMMs in # This block of code determines what N is, and if this WG is operating
# this case. This block of code determines what N is, and if this # on those M rows.
# WG is operating on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
n_blocks = cdiv_fn(seqlen_k, BLOCK_N) if IS_CAUSAL:
if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means
# If seqlen_q != seqlen_k, attn scores are rectangular which # the causal mask boundary is bottom right aligned, and ends at either
# means the causal mask boundary is bottom right aligned, and # the top edge (seqlen_q < seqlen_k) or left edge.
# ends at either the top edge (seqlen_q < seqlen_k) or left # This captures the decrease in n_blocks if we have a rectangular attn
# edge. This captures the decrease in n_blocks if we have a # matrix
# rectangular attn matrix n_blocks_seqlen = cdiv_fn(
n_blocks_seqlen = cdiv_fn( (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only
# This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks = min(n_blocks, n_blocks_seqlen)
# n_blocks # If we have no blocks after adjusting for seqlen deltas, this WG is
n_blocks = min(n_blocks, n_blocks_seqlen) # part of the blocks that are all 0. We exit early.
# If we have no blocks after adjusting for seqlen deltas, this if n_blocks <= 0:
# WG is part of the blocks that are all 0. We exit early. o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
if n_blocks <= 0: off_h_q * stride_oh)
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr(
cu_seqlens_q_start * stride_om) base=Out + o_offset,
o_ptrs = (o_offset + offs_m[:, None] * stride_om + shape=(seqlen_q, BLOCK_DMODEL),
offs_d[None, :] * stride_on) strides=(stride_om, stride_on),
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) offsets=(start_m * BLOCK_M, 0),
o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( block_shape=(BLOCK_M, BLOCK_DMODEL),
[BLOCK_M, BLOCK_DMODEL]) order=(1, 0),
# We still need to write 0s to the result )
tl.store(o_ptrs, acc, mask=o_ptrs_mask) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# The tensor allocated for L is based on MAX_SEQLENS_Q as # We still need to write 0s to the result
# that is statically known. # tl.store(O_block_ptr,
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + # acc.to(Out.type.element_ty), boundary_check=(0,1))
off_h_q * MAX_SEQLENS_Q + offs_m) # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# We store inf to LSE, not -inf because in the bwd pass, # + offs_m
# we subtract this from qk which makes it -inf, such that # We store inf to LSE, not -inf because in the bwd pass,
# exp(qk - inf) = 0 for these masked blocks. # we subtract this
l_value = tl.full([BLOCK_M], # from qk which makes it -inf, such that exp(qk - inf) = 0
value=float("inf"), # for these masked blocks.
dtype=tl.float32) # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
l_ptrs_mask = offs_m < MAX_SEQLENS_Q # tl.store(l_ptrs, l)
tl.store(l_ptrs, l_value, mask=l_ptrs_mask) # TODO: Should dropout and return encoded softmax be handled here?
# TODO: Should dropout and return encoded softmax be return
# handled here too?
continue_condition = False # If MQA / GQA, set the K and V head offsets appropriately.
# return GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
if continue_condition:
# If MQA / GQA, set the K and V head offsets appropriately. n_extra_tokens = 0
GROUP_SIZE: tl.constexpr = HQ // HK if seqlen_k < BLOCK_N:
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = BLOCK_N - seqlen_k
n_extra_tokens = 0 elif seqlen_k % BLOCK_N:
if seqlen_k < BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N
n_extra_tokens = BLOCK_N - seqlen_k padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N # Compute pointers for all the tensors used in this kernel.
USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL q_offset = (off_z * stride_qz + off_h_q * stride_qh +
!= BLOCK_DMODEL) cu_seqlens_q_start * stride_qm)
Q_block_ptr = tl.make_block_ptr(
# Compute pointers for all the tensors used in this kernel. base=Q + q_offset,
q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
cu_seqlens_q_start * stride_qm) strides=(stride_qm, stride_qk),
q_ptrs = (q_offset + offs_m[:, None] * stride_qm + offsets=(start_m * BLOCK_M, 0),
offs_d[None, :] * stride_qk) block_shape=(BLOCK_M, BLOCK_DMODEL),
k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + order=(1, 0),
cu_seqlens_k_start * stride_kn) )
k_ptrs = (k_offset + offs_d[:, None] * stride_kk + k_offset = (off_z * stride_kz + off_h_k * stride_kh +
offs_n[None, :] * stride_kn) cu_seqlens_k_start * stride_kn)
v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + K_block_ptr = tl.make_block_ptr(
cu_seqlens_k_start * stride_vk) base=K + k_offset,
v_ptrs = (v_offset + offs_n[:, None] * stride_vk + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
offs_d[None, :] * stride_vn) strides=(stride_kk, stride_kn),
# Compute pointers for all scale tensors used in this kernel. offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( order=(0, 1),
not IS_EIGHT_BIT_KV) )
if IS_EIGHT_BIT: v_offset = (off_z * stride_vz + off_h_k * stride_vh +
if k_descale_has_singleton: cu_seqlens_k_start * stride_vk)
k_descale_ptrs = k_descale_ptr V_block_ptr = tl.make_block_ptr(
else: base=V + v_offset,
k_descale_ptrs = k_descale_ptr + off_h_k shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
if v_descale_has_singleton: offsets=(0, 0),
v_descale_ptrs = v_descale_ptr block_shape=(BLOCK_N, BLOCK_DMODEL),
else: order=(1, 0),
v_descale_ptrs = v_descale_ptr + off_h_k )
if BIAS_TYPE != 0:
if not IS_EIGHT_BIT_KV: bias_ptr = tl.make_block_ptr(
if q_descale_has_singleton: base=bias + off_h_q * stride_bh,
q_descale_ptrs = q_descale_ptr shape=(seqlen_q, seqlen_k),
else: strides=(stride_bm, stride_bn),
q_descale_ptrs = q_descale_ptr + off_h_q offsets=(start_m * BLOCK_M, 0),
if USE_P_SCALE: block_shape=(BLOCK_M, BLOCK_N),
if p_descale_has_singleton: order=(1, 0),
p_scale_ptrs = p_scale_ptr )
p_descale_ptrs = p_descale_ptr else:
else: bias_ptr = None
p_scale_ptrs = p_scale_ptr + off_h_q if ENABLE_DROPOUT:
p_descale_ptrs = p_descale_ptr + off_h_q batch_philox_offset = philox_offset_base \
+ (off_z * HQ + off_h_q) \
if USE_BIAS: * seqlen_q * seqlen_k
bias_offset = off_h_q * stride_bh else:
bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + batch_philox_offset = 0
offs_n[None, :] * stride_bn) # We can ask to return the dropout mask without actually doing any dropout.
else: # In this case, we return an invalid pointer so indicate the mask is not i
bias_ptrs = None # valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if USE_ALIBI: if RETURN_ENCODED_SOFTMAX:
a_offset = off_z * stride_az + off_h_q * stride_ah encoded_softmax_block_ptr = tl.make_block_ptr(
alibi_slope = tl.load(alibi_slopes + a_offset) base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
else: shape=(seqlen_q, seqlen_k),
alibi_slope = None strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
batch_philox_offset = 0 block_shape=(BLOCK_M, BLOCK_N),
# We can ask to return the dropout mask without doing any order=(1, 0),
# dropout. In this case, we return an invalid pointer so )
# indicate the mask is not valid. else:
if SHOULD_RETURN_ENCODED_SOFTMAX: encoded_softmax_block_ptr = 0
encoded_sm_base = (encoded_softmax + # initialize pointer to m and l
off_h_q * seqlen_q * seqlen_k) m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
encoded_sm_ptrs = (encoded_sm_base + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
offs_m[:, None] * seqlen_k + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
offs_n[None, :]) # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
else: # have native e^x support in HW.
encoded_sm_ptrs = None qk_scale = sm_scale * 1.44269504089
# initialize pointer to m and l # Q is loaded once at the beginning and shared by all N blocks.
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) q = load_fn(Q_block_ptr, True, padded_head, "zero")
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) if not USE_FP8:
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do acc_scale = 1.0
# not have native e^x support in HW. else:
QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 qk_scale *= q_scale * k_scale
# Q is loaded once at the beginning and shared by all N blocks. acc_scale = p_scale * v_scale
q_ptrs_mask = offs_m[:, None] < seqlen_q
if USE_PADDED_HEAD: # Here we compute how many full and masked blocks we have.
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] padded_block_k = n_extra_tokens != 0
< IS_ACTUAL_BLOCK_DMODEL) is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
if IS_EIGHT_BIT: # Additionally there might be one more due to dissimilar seqlens.
k_descale = tl.load(k_descale_ptrs) masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
v_descale = tl.load(v_descale_ptrs) else:
q_descale = None if IS_EIGHT_BIT_KV else tl.load( # Padding on Q does not need to be masked in the FA loop.
q_descale_ptrs) masked_blocks = padded_block_k
if USE_P_SCALE: # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
p_scale = tl.load(p_scale_ptrs) # block. In this case we might exceed n_blocks so pick the min.
p_descale = tl.load(p_descale_ptrs) masked_blocks = min(masked_blocks, n_blocks)
else: n_full_blocks = n_blocks - masked_blocks
p_scale = None block_min = 0
p_descale = None block_max = n_blocks * BLOCK_N
else: # Compute for full blocks. Here we set causal to false regardless of its
q_descale = None # value because there is no masking. Similarly we do not need padding.
k_descale = None if n_full_blocks > 0:
v_descale = None block_max = (n_blocks - masked_blocks) * BLOCK_N
p_scale = None acc, l_i, m_i = _attn_fwd_inner(
p_descale = None acc,
# Here we compute how many full and masked blocks we have. l_i,
padded_block_k = n_extra_tokens != 0 m_i,
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) q,
if IS_CAUSAL: K_block_ptr,
# There are always at least BLOCK_M // BLOCK_N masked V_block_ptr,
# blocks. Additionally there might be one more due to start_m,
# dissimilar seqlens. seqlen_k,
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) dropout_p,
else: philox_seed,
# Padding on Q does not need to be masked in the FA loop. batch_philox_offset,
masked_blocks = padded_block_k encoded_softmax_block_ptr,
# if IS_CAUSAL, not is_modulo_mn does not always result in an # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
# additional block. In this case we might exceed n_blocks so block_min,
# pick the min. block_max,
masked_blocks = min(masked_blocks, n_blocks) 0,
n_full_blocks = n_blocks - masked_blocks 0,
block_min = 0 0,
block_max = n_blocks * BLOCK_N bias_ptr,
# Compute for full blocks. Here we set causal to false # IS_CAUSAL, ....
# regardless of its actual value because there is no masking. False,
# Similarly we do not need padding. BLOCK_M,
if n_full_blocks > 0: BLOCK_DMODEL,
block_max = (n_blocks - masked_blocks) * BLOCK_N BLOCK_N,
acc, l_i, m_i = _attn_fwd_inner( offs_m,
acc, offs_n,
l_i, # _, MASK_STEPS, ...
m_i, PRE_LOAD_V,
q, False,
k_ptrs, ENABLE_DROPOUT,
v_ptrs, RETURN_ENCODED_SOFTMAX,
bias_ptrs, padded_head,
stride_kn, USE_FP8,
stride_vk, qk_scale,
stride_bn, p_descale,
start_m, )
seqlen_k, block_min = block_max
seqlen_q, block_max = n_blocks * BLOCK_N
philox_seed,
batch_philox_offset, tl.debug_barrier()
encoded_sm_ptrs, # Remaining blocks, if any, are full / not masked.
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ if masked_blocks > 0:
block_min, offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
block_max, K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
0, V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
0, if bias_ptr is not None:
0, bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
alibi_slope, if RETURN_ENCODED_SOFTMAX:
q_descale, encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
k_descale, (0, n_full_blocks))
v_descale, acc, l_i, m_i = _attn_fwd_inner(
p_scale, acc,
# IS_CAUSAL, .... l_i,
False, m_i,
BLOCK_M, q,
BLOCK_DMODEL, K_block_ptr,
BLOCK_N, V_block_ptr,
offs_m, start_m,
offs_n, seqlen_k,
# _, SHOULD_MASK_STEPS, ... dropout_p,
SHOULD_PRE_LOAD_V, philox_seed,
False, batch_philox_offset,
SHOULD_RETURN_ENCODED_SOFTMAX, encoded_softmax_block_ptr,
USE_PADDED_HEAD, block_min,
IS_ACTUAL_BLOCK_DMODEL, block_max,
QK_SCALE, offs_n_causal,
IS_EIGHT_BIT_GEMM, masked_blocks,
USE_P_SCALE, n_extra_tokens,
IS_EIGHT_BIT_KV, bias_ptr,
QUANT_DTYPE) IS_CAUSAL,
block_min = block_max BLOCK_M,
block_max = n_blocks * BLOCK_N BLOCK_DMODEL,
BLOCK_N,
tl.debug_barrier() offs_m,
# Remaining blocks, if any, are full / not masked. offs_n,
if (masked_blocks > 0): # _, MASK_STEPS, ...
if IS_CAUSAL: PRE_LOAD_V,
offs_n_causal = offs_n + (seqlen_q - seqlen_k) True,
else: ENABLE_DROPOUT,
offs_n_causal = 0 RETURN_ENCODED_SOFTMAX,
k_ptrs += n_full_blocks * BLOCK_N * stride_kn padded_head,
v_ptrs += n_full_blocks * BLOCK_N * stride_vk USE_FP8,
if USE_BIAS: qk_scale,
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn p_descale,
if SHOULD_RETURN_ENCODED_SOFTMAX: )
encoded_sm_ptrs += n_full_blocks * BLOCK_N # epilogue
acc, l_i, m_i = _attn_fwd_inner(
acc, if USE_FP8:
l_i, acc *= acc_scale
m_i, acc = acc / l_i[:, None]
q, if ENABLE_DROPOUT:
k_ptrs, acc = acc / (1 - dropout_p)
v_ptrs, # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
bias_ptrs, # then we have one block with a row of all NaNs which come from computing
stride_kn, # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
stride_vk, # and store 0s where there are NaNs as these rows should've been zeroed out.
stride_bn, end_m_idx = (start_m + 1) * BLOCK_M
start_m, start_m_idx = start_m * BLOCK_M
seqlen_k, causal_start_idx = seqlen_q - seqlen_k
seqlen_q, if USE_FP8_OUT:
philox_seed, acc *= o_descale
batch_philox_offset, acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
encoded_sm_ptrs, acc = acc.to(Out.type.element_ty)
block_min, if IS_CAUSAL: # noqa: SIM102
block_max, if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
offs_n_causal, out_mask_boundary = tl.full((BLOCK_DMODEL, ),
masked_blocks, causal_start_idx,
n_extra_tokens, dtype=tl.int32)
alibi_slope, mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
q_descale, out_ptrs_mask = (mask_m_offsets[:, None]
k_descale, >= out_mask_boundary[None, :])
v_descale, z = tl.zeros((1, ), tl.float32)
p_scale, acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
IS_CAUSAL, # write back LSE
BLOCK_M, # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
BLOCK_DMODEL, # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
BLOCK_N, # few rows. This is only true for the last M block. For others,
offs_m, # overflow_size will be -ve
offs_n, # overflow_size = end_m_idx - seqlen_q
# _, SHOULD_MASK_STEPS, ... # if overflow_size > 0:
SHOULD_PRE_LOAD_V, # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
True, # # This is a > check because mask being 0 blocks the store.
SHOULD_RETURN_ENCODED_SOFTMAX, # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
USE_PADDED_HEAD, # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
IS_ACTUAL_BLOCK_DMODEL, # else:
QK_SCALE, # tl.store(l_ptrs, m_i + tl.math.log2(l_i))
IS_EIGHT_BIT_GEMM,
USE_P_SCALE, # write back O
IS_EIGHT_BIT_KV, o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
QUANT_DTYPE) off_h_q * stride_oh)
O_block_ptr = tl.make_block_ptr(
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: base=Out + o_offset,
if USE_P_SCALE: shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
acc *= p_descale strides=(stride_om, stride_on),
acc *= v_descale offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
# epilogue order=(1, 0),
# This helps the compiler do Newton Raphson on l_i vs on acc )
# which is much larger. # Need boundary check on this to make sure the padding from the
l_recip = 1 / l_i[:, None] # Q and KV tensors in both dims are not part of what we store back.
acc = acc * l_recip # TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
# If seqlen_q > seqlen_k but the delta is not a multiple of
# BLOCK_M, then we have one block with a row of all NaNs which
# come from computing softmax over a row of all def check_args(
# -infs (-inf - inf = NaN). We check for that here and store 0s q,
# where there are NaNs as these rows should've been zeroed out. k,
end_m_idx = (start_m + 1) * BLOCK_M v,
start_m_idx = start_m * BLOCK_M o,
causal_start_idx = seqlen_q - seqlen_k varlen=True,
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 max_seqlens=None,
if o_descale_ptr is not None: cu_seqlens_q=None,
acc = quant_fp8(acc, o_descale) cu_seqlens_k=None,
):
acc = acc.to(Out.type.element_ty) assert q.dim() == k.dim() and q.dim() == v.dim()
if IS_CAUSAL: # noqa: SIM102 if varlen:
if (causal_start_idx > start_m_idx assert q.dim() == 3
and causal_start_idx < end_m_idx): total_q, nheads_q, head_size = q.shape
out_mask_boundary = tl.full((BLOCK_DMODEL, ), total_k, nheads_k, _ = k.shape
causal_start_idx, assert cu_seqlens_q is not None
dtype=tl.int32) assert cu_seqlens_k is not None
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) assert len(cu_seqlens_q) == len(cu_seqlens_k)
out_ptrs_mask = (mask_m_offsets[:, None] else:
>= out_mask_boundary[None, :]) assert q.dim() == 4
z = tl.zeros((1, ), tl.float32) batch, nheads_q, seqlen_q, head_size = q.shape
acc = tl.where(out_ptrs_mask, acc, _, nheads_k, seqlen_k, _ = k.shape
z.to(acc.type.element_ty)) assert max_seqlens > 0
# write back LSE assert k.shape == v.shape
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
off_h_q * MAX_SEQLENS_Q + offs_m) # TODO: Change assert if we support qkl f8 and v f16
# If seqlen_q not multiple of BLOCK_M, we need to mask out the assert q.dtype == k.dtype and q.dtype == v.dtype
# last few rows. This is only true for the last M block. assert head_size <= 256
# For others, overflow_size will be -ve assert o.shape == q.shape
overflow_size = end_m_idx - seqlen_q assert (nheads_q % nheads_k) == 0
if overflow_size > 0:
boundary = tl.full((BLOCK_M, ),
BLOCK_M - overflow_size,
dtype=tl.int32)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh +
cu_seqlens_q_start * stride_om)
o_ptrs = (o_offset + offs_m[:, None] * stride_om +
offs_d[None, :] * stride_on)
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1)
if overflow_size > 0:
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q)
if USE_PADDED_HEAD:
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :]
< IS_ACTUAL_BLOCK_DMODEL)
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask)
def get_shape_from_layout(q, k, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd':
nheads_q, nheads_k = q.shape[1], k.shape[1]
head_size = q.shape[-1]
batch = metadata.num_contexts
elif metadata.layout == 'bhsd':
batch, nheads_q, _, head_size = q.shape
nheads_k = k.shape[1]
elif metadata.layout == 'bshd':
batch, _, nheads_q, head_size = q.shape
nheads_k = k.shape[2]
return batch, nheads_q, nheads_k, head_size
def get_strides_from_layout(q, k, v, o, metadata):
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
STRIDE_PERMUTATIONS = {
'thd': (None, 1, 0, 2),
'bhsd': (0, 1, 2, 3),
'bshd': (0, 2, 1, 3),
}
perm = STRIDE_PERMUTATIONS[metadata.layout]
stride = lambda x, p: (0 if p is None else x.stride(p))
strides = lambda x: (stride(x, p) for p in perm)
return tuple(strides(x) for x in [q, k, v, o])
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, o, metadata: MetaData): def forward(
# NOTE: a large bias tensor leads to overflow during pointer arithmetic ctx,
if (metadata.bias is not None): q,
assert (metadata.bias.numel() < 2**31) k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
fp8_scales=None,
fp8_out_scale=None,
):
if fp8_scales is not None:
use_fp8 = True
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
float8 = current_platform.fp8_dtype()
def check_and_convert(t, scale):
if t.dtype != float8:
descale = 1.0 / scale
ts = (t * descale).clamp(min=float8_info.min,
max=float8_info.max)
return ts.to(float8)
else:
return t
if o is None: q = check_and_convert(q, q_scale)
if metadata.eight_bit: k = check_and_convert(k, k_scale)
o = torch.empty_like( v = check_and_convert(v, v_scale)
q, else:
dtype=metadata.output_dtype if metadata.output_dtype use_fp8 = False
is not None else metadata.eight_bit_dtype_torch) q_scale = k_scale = v_scale = p_scale = 1.0
else:
o = torch.empty_like(q, dtype=q.dtype)
metadata.check_args(q, k, v, o) if o is None:
o = torch.empty_like(q, dtype=v.dtype)
batch, nheads_q, nheads_k, head_size = get_shape_from_layout( check_args(
q, k, metadata) q,
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( k,
q, k, v, o, metadata) v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. # Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length() unpadded_head_dims = {32, 64, 128, 256}
# Smallest head_dim supported is 16. If smaller, the tile in the if head_size not in unpadded_head_dims:
# kernel is padded - there is no padding in memory for any dims. padded_d_model = None
padded_d_model = max(padded_d_model, 16) for i in unpadded_head_dims:
if i > head_size:
# encoded_softmax is used to validate dropout behavior vs the padded_d_model = i
# PyTorch SDPA math backend reference. We zero this out to give a break
# consistent starting point and then populate it with the output of assert padded_d_model is not None
# softmax with the sign bit set according to the dropout mask.
# The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else: else:
encoded_softmax = None padded_d_model = head_size
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), encoded_softmax = None
device=q.device,
dtype=torch.float32)
# Seed the RNG so we get reproducible results for testing. # Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52 philox_seed = 0x1BF52
philox_offset = 0x1D4B42 philox_offset = 0x1D4B42
if metadata.bias is not None: if bias is not None:
bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), bias_strides = (
metadata.bias.stride(2), metadata.bias.stride(3)) bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else: else:
bias_strides = (0, 0, 0, 0) bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None: p_descale = 1.0 / p_scale
alibi_strides = (metadata.alibi_slopes.stride(0), o_descale = 1.0 / fp8_out_scale.item(
metadata.alibi_slopes.stride(1)) ) if fp8_out_scale is not None else 1.0
else:
alibi_strides = (0, 0)
if metadata.eight_bit: arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid]( attn_fwd[grid](
q, q,
k, k,
v, v,
metadata.bias, bias,
metadata.sm_scale, sm_scale,
M, q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
None,
o, o,
*q_strides, *q_strides,
*k_strides, *k_strides,
*v_strides, *v_strides,
*o_strides, *o_strides,
*bias_strides, *bias_strides,
*alibi_strides, cu_seqlens_q,
q_descale, cu_seqlens_k,
k_descale, dropout_p=0.0,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed, philox_seed=philox_seed,
philox_offset_base=philox_offset, philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HQ=nheads_q,
HK=nheads_k, HK=nheads_k,
IS_ACTUAL_BLOCK_DMODEL=head_size, ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_Q=arg_max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k, MAX_SEQLENS_K=arg_max_seqlens_k,
IS_CAUSAL=metadata.causal, IS_CAUSAL=causal,
VARLEN=metadata.varlen, VARLEN=True,
BLOCK_DMODEL=padded_d_model, BLOCK_DMODEL=padded_d_model,
USE_BIAS=metadata.bias is not None, BIAS_TYPE=0 if bias is None else 1,
USE_ALIBI=metadata.alibi_slopes is not None, ENABLE_DROPOUT=False,
SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, RETURN_ENCODED_SOFTMAX=False,
IS_EIGHT_BIT=metadata.eight_bit, USE_FP8=use_fp8,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, USE_FP8_OUT=fp8_out_scale is not None,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, )
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid ctx.grid = grid
ctx.sm_scale = metadata.sm_scale ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size ctx.BLOCK_DMODEL = head_size
ctx.causal = metadata.causal ctx.causal = causal
ctx.alibi_slopes = metadata.alibi_slopes ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = metadata.return_encoded_softmax ctx.return_encoded_softmax = False
return o, encoded_softmax return o, encoded_softmax
triton_attention_rocm = _attention.apply triton_attention = _attention.apply
def scale_fp8(t, scale=None):
t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]),
scale)
return t_scaled.reshape(t.shape), scale_out
def maybe_quantize_fp8(t, scale):
eight_bit_dtype = current_platform.fp8_dtype()
if t.dtype != eight_bit_dtype:
t, _ = scale_fp8(t, scale)
return t
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
q = maybe_quantize_fp8(q, q_scale)
k = maybe_quantize_fp8(k, k_scale)
v = maybe_quantize_fp8(v, v_scale)
return q, k, v
# query - [num_tokens, num_heads, head_size]
# key - [num_tokens, num_kv_heads, head_size]
# value - [num_tokens, num_kv_heads, head_size
# output - [num_tokens, num_heads, head_size]
def triton_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlens_q: int,
max_seqlens_k: int,
causal: bool = False,
sm_scale: float = 1.0,
bias: Optional[torch.Tensor] = None,
fp8_scales: Optional[tuple[float, ...]] = None,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if fp8_scales is not None:
q_descale, k_descale, v_descale, p_scale = fp8_scales
else:
q_descale = k_descale = v_descale = p_scale = None
attn_metadata = MetaData(sm_scale=sm_scale,
max_seqlens_q=max_seqlens_q,
max_seqlens_k=max_seqlens_k,
causal=causal,
bias=bias,
output_dtype=q.dtype,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=input_scale)
if fp8_scales is not None:
q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)
return triton_attention_rocm(q, k, v, o, attn_metadata)
...@@ -98,6 +98,12 @@ def with_amdsmi_context(fn): ...@@ -98,6 +98,12 @@ def with_amdsmi_context(fn):
return wrapper return wrapper
@cache
def on_gfx1x() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@cache @cache
def on_mi250_mi300() -> bool: def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment