Unverified Commit 1b7cfd5a authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm][V0][Attention] Revert to the previous FA triton kernel (#18226)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent da4b69d0
...@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -770,8 +770,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
and layer._v_scale and layer._prob_scale and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8") and self.kv_cache_dtype == "fp8")
full_scales = ( full_scales = (
layer._q_scale, layer._k_scale, layer._v_scale, layer._q_scale.item(), layer._k_scale.item(),
layer._prob_scale) if use_fp8_scales else None layer._v_scale.item(),
layer._prob_scale.item()) if use_fp8_scales else None
self.triton_attn_func( self.triton_attn_func(
query, query,
key, key,
......
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
This is a Triton implementation of the Flash Attention v2 algorithm This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
See https://tridao.me/publications/flash2/flash2.pdf (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Credits: Features supported:
AMD Triton kernels team
OpenAI kernel team
Currently only the forward kernel is supported, and contains these features:
1) Fwd with causal masking 1) Fwd with causal masking
2) Arbitrary Q and KV sequence lengths 2) Any sequence lengths without padding (currently fwd kernel only)
3) Arbitrary head sizes 3) Support for different sequence lengths for q and k
4) Multi and grouped query attention 4) Nested tensor API currently does not support dropout or bias.
5) Variable sequence lengths
6) ALiBi and matrix bias
7) FP8 support
""" Not currently supported:
from typing import Optional 1) Non power of two head dims
"""
import torch import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx1x
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] torch_dtype: tl.constexpr = torch.float16
default_eight_bit_dtype_triton = tl.float8e4b8
default_eight_bit_dtype_torch = current_platform.fp8_dtype()
default_float8_info = torch.finfo(default_eight_bit_dtype_torch)
FP8_MIN = triton.language.constexpr(default_float8_info.min)
# According to https://github.com/vllm-project/vllm/blob/main
# /csrc/quantization/utils.cuh#L31,
# need to make the max for the uz datatype be 224.0 for accuracy reasons.
FP8_MAX = triton.language.constexpr(
default_float8_info.max if default_eight_bit_dtype_torch !=
torch.float8_e4m3fnuz else 224.0)
class MetaData:
cu_seqlens_q = None
cu_seqlens_k = None
max_seqlens_q = 0
max_seqlens_k = 0
bias = None
alibi_slopes = None
causal = False
num_contexts = 0
varlen = False
eight_bit = False
layout = None
return_encoded_softmax = False
eight_bit_dtype_triton = default_eight_bit_dtype_triton
eight_bit_dtype_torch = default_eight_bit_dtype_torch
output_dtype = None
# Note about layouts:
#
# thd - [num_tokens, num_heads, head_size]
# bshd - [batch_size, seq_len, num_heads, head_size]
# bhsd - [batch_size, num_heads, seq_len, head_size]
#
# This is for each tensor, all tensors must have same layout.
# Q can have num_heads and seq_len differ from from K and V,
# however K and V must agree on this.
#
# Notes about varlen and bias:
# Only one or the other is implemented, meaning can't combine
# both varlen and bias right now.
#
# Note about quantization:
# Only 8-bit quantization supported (for now) and specifically fp8.
# Scales must be tensors.
# o_scale: This is 'output scaling', but comes from parameter called
# 'input_scale', this is applied to the output from the kernel.
# o_scale should be None if none of the other quantization parameters
# are used.
#
# NOTE: Object is in a tentatively good state after initialized, however,
# to verify, call check_args(q,k,v,o) where o is the output tensor.
def __init__(
self,
sm_scale=1.0,
layout=None, # layout can be 'bshd', 'bhsd', or 'thd'
output_dtype=None,
max_seqlens_q=0,
max_seqlens_k=0,
# varlen params
cu_seqlens_q=None, # only 'thd' layout supported for varlen
cu_seqlens_k=None,
# quant params
q_descale=None,
k_descale=None,
v_descale=None,
p_scale=None,
o_scale=None,
# bias params
bias=None, # varlen not implemented for bias
seqlen_q=None,
seqlen_k=None,
# alibi params
alibi_slopes=None,
alibi_batch=None,
alibi_nheads=None,
# causal
causal=None,
):
self.sm_scale = sm_scale
self.output_dtype = output_dtype
self.max_seqlens_q = max_seqlens_q
self.max_seqlens_k = max_seqlens_k
self.layout = layout
if cu_seqlens_q is not None or cu_seqlens_k is not None:
assert cu_seqlens_q is not None and cu_seqlens_k is not None
assert layout is None or layout not in [
'bshd', 'bhsd'
], "Varlen only implemented for thd layout"
self.set_varlen_params(cu_seqlens_q, cu_seqlens_k)
quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale]
if any(x is not None for x in quant_params):
p_descale = 1.0 / p_scale if p_scale is not None else None
self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale)
if bias is not None:
self.need_bias(bias, seqlen_q, seqlen_k)
if alibi_slopes is not None:
self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads)
if causal is not None and causal:
self.need_causal()
def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
assert len(cu_seqlens_q) >= 2
assert len(cu_seqlens_q) == len(cu_seqlens_k)
self.num_contexts = len(cu_seqlens_q) - 1
for i in range(0, self.num_contexts):
self.max_seqlens_q = max(
cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(),
self.max_seqlens_q)
self.max_seqlens_k = max(
cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(),
self.max_seqlens_k)
def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale,
p_descale, o_scale):
self.eight_bit = True
self.q_descale = q_descale
self.k_descale = k_descale
self.v_descale = v_descale
self.p_scale = p_scale
self.p_descale = p_descale
self.o_scale = o_scale
self.use_p_scale = (p_scale is not None) and (
p_descale is not None) and (v_descale is not None)
self.eight_bit_kv = ((q_descale is None) and (k_descale is not None)
and (v_descale is not None))
self.eight_bit_dtype_torch = default_eight_bit_dtype_torch
def need_bias(self, bias, seqlen_q, seqlen_k):
assert bias is not None
assert bias.is_cuda
assert bias.dim() == 4
assert bias.shape[0] == 1
assert bias.shape[2:] == (seqlen_q, seqlen_k)
self.bias = bias
def need_alibi(self, alibi_slopes, batch, nheads):
assert alibi_slopes.is_cuda
assert alibi_slopes.dim() == 2
assert alibi_slopes.shape[0] == batch
assert alibi_slopes.shape[1] == nheads
self.alibi_slopes = alibi_slopes
def need_causal(self):
self.causal = True
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()
batch, nheads_q, nheads_k, head_size = get_shape_from_layout(
q, k, self)
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
assert self.cu_seqlens_k is not None
assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k)
# TODO: Remove once bias is supported with varlen
assert self.bias is None
assert not self.return_encoded_softmax
else:
assert q.dim() == 4
assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0
assert self.cu_seqlens_q is None and self.cu_seqlens_k is None
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
if self.eight_bit:
if self.eight_bit_kv:
assert (v.dtype == k.dtype
and k.dtype == self.eight_bit_dtype_torch)
assert q.dtype != k.dtype
assert (self.v_descale is not None) and (self.k_descale
is not None)
else:
assert (q.dtype == k.dtype and q.dtype == v.dtype
and q.dtype == self.eight_bit_dtype_torch)
assert (self.q_descale
is not None) and (self.k_descale
is not None) and (self.v_descale
is not None)
if self.use_p_scale:
assert (self.p_scale is not None) and (self.p_descale
is not None)
else:
assert (q.dtype == k.dtype) and (q.dtype == v.dtype)
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
@triton.jit @triton.jit
...@@ -243,85 +40,40 @@ def max_fn(x, y): ...@@ -243,85 +40,40 @@ def max_fn(x, y):
return tl.math.max(x, y) return tl.math.max(x, y)
# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@triton.jit @triton.jit
def masked_load(ptrs, offset_first, offset_second, boundary_first, def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
boundary_second): ms = tl.arange(0, m)
if offset_first is not None and offset_second is not None: ns = tl.arange(0, n)
mask = (offset_first[:, None] < boundary_first) & \ return philox_offset + ms[:, None] * stride + ns[None, :]
(offset_second[None, :] < boundary_second)
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_first is not None:
mask = offset_first[:, None] < boundary_first
tensor = tl.load(ptrs, mask=mask, other=0.0)
elif offset_second is not None:
mask = offset_second[None, :] < boundary_second
tensor = tl.load(ptrs, mask=mask, other=0.0)
else:
tensor = tl.load(ptrs)
return tensor
@triton.jit @triton.jit
def compute_alibi_block(alibi_slope, def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
seqlen_q, rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
seqlen_k, stride).to(tl.uint32)
offs_m, # TODO: use tl.randint for better performance
offs_n, return tl.rand(philox_seed, rng_offsets)
transpose=False):
# when seqlen_k and seqlen_q are different we want the diagonal to stick to
# the bottom right of the attention matrix
# for casual mask we want something like this where (1 is kept and 0 is
# masked)
# seqlen_q = 2 and seqlen_k = 5
# 1 1 1 1 0
# 1 1 1 1 1
# seqlen_q = 5 and seqlen_k = 2
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
# for alibi the diagonal is 0 indicating no penalty for attending to that
# spot and increasing penalty for attending further from the diagonal
# e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5,
# offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False
# 1. offs_m[:,None] = [[0],
# [1],
# 2. offs_m[:,None] + seqlen_k = [[5],
# [6],
# 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3],
# [4],
# 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] =
# [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]]
# 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1],
# [ -4, -3, -2, -1, 0]],
relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q -
offs_n[None, :])
alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block)
if transpose:
return alibi_block.T
else:
return alibi_block
def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): @triton.jit
q_idx = torch.arange(seqlen_q, dtype=torch.int32, def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
k_idx = torch.arange(seqlen_k, dtype=torch.int32, stride)
device="cuda").unsqueeze(0) # (1, N_CTX_K) rng_keep = rng_output > dropout_p
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - return rng_keep
k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(
-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)
@triton.jit @triton.jit
def quant_fp8(x, scale): def load_fn(block_ptr, first, second, pad):
x *= scale if first and second:
x = tl.clamp(x, FP8_MIN, FP8_MAX) tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
return x elif first:
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit @triton.jit
...@@ -330,68 +82,61 @@ def _attn_fwd_inner( ...@@ -330,68 +82,61 @@ def _attn_fwd_inner(
l_i, l_i,
m_i, m_i,
q, q,
k_ptrs, K_block_ptr,
v_ptrs, V_block_ptr,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
actual_seqlen_k, actual_seqlen_k,
actual_seqlen_q, dropout_p,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_sm_ptrs, encoded_softmax_block_ptr,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
alibi_slope, bias_ptr,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr, OFFS_N: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr, PRE_LOAD_V: tl.constexpr,
SHOULD_MASK_STEPS: tl.constexpr, MASK_STEPS: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
USE_PADDED_HEAD: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, PADDED_HEAD: tl.constexpr,
QK_SCALE: tl.constexpr, USE_FP8: tl.constexpr,
IS_EIGHT_BIT_GEMM: tl.constexpr, qk_scale,
USE_P_SCALE: tl.constexpr, p_descale,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
# loop over k, v, and update accumulator # loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N): for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if # For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range. # we load all BLOCK_N. For others, the blocks are all within range.
k_offs_n = start_n + tl.arange(0, k = load_fn(
BLOCK_N) if SHOULD_MASK_STEPS else None K_block_ptr,
k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) PADDED_HEAD,
k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, MASK_STEPS and (n_extra_tokens != 0),
actual_seqlen_k) "zero",
if SHOULD_PRE_LOAD_V: )
# We can use the same offsets as k, just with dims transposed. if PRE_LOAD_V:
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, v = load_fn(
IS_ACTUAL_BLOCK_DMODEL) V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need # We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n # to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block. # TODO: This can be optimized to only be true for the padded block.
if SHOULD_MASK_STEPS: # noqa: SIM102 if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to # If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size # mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not # a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# is_modulo_mn. last step might get wasted but that is okay. # if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case. # check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], boundary_m = tl.full([BLOCK_M],
...@@ -404,107 +149,112 @@ def _attn_fwd_inner( ...@@ -404,107 +149,112 @@ def _attn_fwd_inner(
causal_boundary = start_n + offs_n_causal causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf")) qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ---- # -- compute qk ----
if IS_EIGHT_BIT_GEMM: qk += tl.dot(q, k)
qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * if USE_FP8:
QK_SCALE) qk *= qk_scale
else: if bias_ptr is not None:
if IS_EIGHT_BIT_KV: bias = load_fn(bias_ptr, False, MASK_STEPS
k = (k * k_descale).to(q.type.element_ty) and (n_extra_tokens != 0), "zero")
qk += (tl.dot(q, k) * QK_SCALE) # While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
if bias_ptrs is not None: # scale factor of log2(e) which we must also multiply the bias with.
bias_offs_n = start_n + tl.arange( qk += bias * 1.44269504089
0, BLOCK_N) if SHOULD_MASK_STEPS else None
bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q,
actual_seqlen_k)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an
# additional scale factor of log2(e) which we must also multiply
# the bias with.
qk += (bias * 1.44269504089)
if alibi_slope is not None:
# Compute the global position of each token within the sequence
global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
global_n_positions = start_n + tl.arange(0, BLOCK_N)
alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q,
actual_seqlen_k,
global_m_positions,
global_n_positions)
qk += (alibi_block * 1.44269504089) # scale factor of log2(e)
# softmax
m_ij = tl.maximum(m_i, tl.max(qk, 1)) m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None] qk = qk - m_ij[:, None]
p = tl.math.exp2(qk) p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout # CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) l_ij = tl.sum(p, 1)
if SHOULD_RETURN_ENCODED_SOFTMAX: if ENABLE_DROPOUT:
tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) philox_offset = (batch_philox_offset +
start_m * BLOCK_M * actual_seqlen_k + start_n -
BLOCK_N)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p,
-p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator -- # -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij) alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None] acc = acc * alpha[:, None]
if not SHOULD_PRE_LOAD_V: if not PRE_LOAD_V:
v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, v = load_fn(
IS_ACTUAL_BLOCK_DMODEL) V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i # -- update m_i and l_i
l_i = l_i * alpha + l_ij l_i = l_i * alpha + l_ij
# update m_i and l_i # update m_i and l_i
m_i = m_ij m_i = m_ij
if IS_EIGHT_BIT_GEMM: if USE_FP8:
if USE_P_SCALE: p *= p_descale
p = quant_fp8(p, p_scale).to(QUANT_DTYPE)
acc += tl.dot(p, v) acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
else:
# v is in eight_bit but p is not, we want the gemm in p's type V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
acc += tl.dot(p, v.to(p.type.element_ty)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
else: if bias_ptr is not None:
if IS_EIGHT_BIT_KV: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
v = (v * v_descale).to(p.type.element_ty) if RETURN_ENCODED_SOFTMAX:
acc += tl.dot(p.to(v.type.element_ty), v) encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
(0, BLOCK_N))
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
bias_ptrs += BLOCK_N * stride_bn
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += BLOCK_N
return acc, l_i, m_i return acc, l_i, m_i
def get_cdna_autotune_configs(): def get_cdna_autotune_configs():
return [ return [
triton.Config(
{
'BLOCK_M': 256,
'BLOCK_N': 64,
'waves_per_eu': 2,
'PRE_LOAD_V': False
},
num_stages=1,
num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 128, 'BLOCK_N': 128,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 256,
'BLOCK_N': 64, 'BLOCK_N': 128,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 64, 'BLOCK_N': 64,
'waves_per_eu': 3, 'waves_per_eu': 1,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
...@@ -512,168 +262,141 @@ def get_cdna_autotune_configs(): ...@@ -512,168 +262,141 @@ def get_cdna_autotune_configs():
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 64, 'BLOCK_N': 64,
'waves_per_eu': 1, 'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': True
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
triton.Config( triton.Config(
{ {
'BLOCK_M': 128, 'BLOCK_M': 128,
'BLOCK_N': 32, 'BLOCK_N': 64,
'waves_per_eu': 2, 'waves_per_eu': 3,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=4), num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 64,
'BLOCK_N': 32, 'BLOCK_N': 64,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=8),
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 32, 'BLOCK_N': 32,
'waves_per_eu': 2, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=8),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
# Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
# triton.Config(
# {
# "BLOCK_M": 16,
# "BLOCK_N": 16,
# "waves_per_eu": 1,
# "PRE_LOAD_V": False,
# },
# num_stages=1,
# num_warps=4,
# ),
], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
def get_rdna_autotune_configs():
return [
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 32,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 32, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 32,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 16, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 16,
'waves_per_eu': 4, 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
triton.Config( triton.Config(
{ {
'BLOCK_M': 16, 'BLOCK_M': 32,
'BLOCK_N': 16, 'BLOCK_N': 16,
'waves_per_eu': 2, 'waves_per_eu': 2,
'SHOULD_PRE_LOAD_V': False, 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2
}, },
num_stages=1, num_stages=1,
num_warps=2), num_warps=2),
# Fall-back config. # Fails in AccelerateAMDMatmul (Triton) assert when using FP8:
triton.Config( # triton.Config(
{ # {
'BLOCK_M': 16, # 'BLOCK_M': 16,
'BLOCK_N': 16, # 'BLOCK_N': 16,
'waves_per_eu': 1, # 'waves_per_eu': 4,
'SHOULD_PRE_LOAD_V': False, # 'PRE_LOAD_V': False
'GRID_CU_MULTIP': 2 # },
}, # num_stages=1,
num_stages=1, # num_warps=2),
num_warps=2), # triton.Config(
], [ # {
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', # 'BLOCK_M': 16,
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' # 'BLOCK_N': 16,
] # 'waves_per_eu': 2,
# 'PRE_LOAD_V': False
# },
def get_general_autotune_configs(): # num_stages=1,
return [ # num_warps=2),
triton.Config( # # Fall-back config.
{ # triton.Config(
'BLOCK_M': 128, # {
'BLOCK_N': 128, # 'BLOCK_M': 16,
'SHOULD_PRE_LOAD_V': False, # 'BLOCK_N': 16,
'GRID_CU_MULTIP': 2 # 'waves_per_eu': 1,
}, # 'PRE_LOAD_V': False
num_stages=1, # },
num_warps=4), # num_stages=1,
triton.Config( # num_warps=2),
{ ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8']
'BLOCK_M': 128,
'BLOCK_N': 64,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
triton.Config(
{
'BLOCK_M': 128,
'BLOCK_N': 32,
'SHOULD_PRE_LOAD_V': False,
'GRID_CU_MULTIP': 2
},
num_stages=1,
num_warps=4),
], [
'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K',
'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'
]
def has_cdna_target():
ROCM_CDNA_TARGETS = ["gfx942", "gfx90a", "gfx908"]
return triton.runtime.driver.active.get_current_target(
).arch in ROCM_CDNA_TARGETS
def is_rocm_cdna():
return current_platform.is_rocm() and has_cdna_target()
def get_autotune_configs(): def get_autotune_configs():
if is_rocm_cdna(): if on_gfx1x():
return get_cdna_autotune_configs()
elif current_platform.is_rocm():
return get_rdna_autotune_configs() return get_rdna_autotune_configs()
else: else:
return get_general_autotune_configs() return get_cdna_autotune_configs()
autotune_configs, autotune_keys = get_autotune_configs() autotune_configs, autotune_keys = get_autotune_configs()
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.autotune( @triton.autotune(
configs=autotune_configs, configs=autotune_configs,
key=autotune_keys, key=autotune_keys,
use_cuda_graph=True,
) )
@triton.jit @triton.jit
def attn_fwd( def attn_fwd(
...@@ -681,7 +404,13 @@ def attn_fwd( ...@@ -681,7 +404,13 @@ def attn_fwd(
K, K,
V, V,
bias, bias,
SM_SCALE: tl.constexpr, sm_scale,
q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
L, L,
Out, Out,
stride_qz: tl.int64, stride_qz: tl.int64,
...@@ -704,70 +433,44 @@ def attn_fwd( ...@@ -704,70 +433,44 @@ def attn_fwd(
stride_bh: tl.int64, stride_bh: tl.int64,
stride_bm: tl.int64, stride_bm: tl.int64,
stride_bn: tl.int64, stride_bn: tl.int64,
stride_az: tl.int64,
stride_ah: tl.int64,
q_descale_ptr,
k_descale_ptr,
p_scale_ptr,
p_descale_ptr,
o_descale_ptr,
v_descale_ptr,
q_descale_has_singleton: tl.constexpr,
k_descale_has_singleton: tl.constexpr,
p_descale_has_singleton: tl.constexpr,
v_descale_has_singleton: tl.constexpr,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
dropout_p,
philox_seed, philox_seed,
NUM_CU: tl.constexpr,
GRID_CU_MULTIP: tl.constexpr,
B: tl.constexpr,
philox_offset_base, philox_offset_base,
encoded_softmax, encoded_softmax,
alibi_slopes,
HQ: tl.constexpr, HQ: tl.constexpr,
HK: tl.constexpr, HK: tl.constexpr,
IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr, VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr, IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
USE_FP8: tl.constexpr,
USE_FP8_OUT: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
SHOULD_PRE_LOAD_V: tl.constexpr, PRE_LOAD_V: tl.constexpr,
USE_BIAS: tl.constexpr, BIAS_TYPE: tl.constexpr,
SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, ENABLE_DROPOUT: tl.constexpr,
USE_ALIBI: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
IS_EIGHT_BIT: tl.constexpr, FP8_MIN: tl.constexpr = float8_info.min,
USE_P_SCALE: tl.constexpr, FP8_MAX: tl.constexpr = float8_info.max,
IS_EIGHT_BIT_KV: tl.constexpr,
QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton,
): ):
start_m = tl.program_id(0)
if o_descale_ptr is not None: off_h_q = tl.program_id(1)
o_descale = tl.load(o_descale_ptr) off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
start_m: tl.int64 = tl.program_id(0) offs_n = tl.arange(0, BLOCK_N)
off_h_q: tl.int64 = tl.program_id(1)
off_z: tl.int64 = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64)
offs_n = tl.arange(0, BLOCK_N).to(tl.int64)
offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64)
# as we can't have return statements inside while loop in Triton
continue_condition = True
if VARLEN: if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be # We have a one-size-fits-all grid in id(0). Some seqlens might be too
# too small for all start_m so for those we return early. # small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q: if start_m * BLOCK_M > seqlen_q:
continue_condition = False return
# return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
...@@ -777,186 +480,161 @@ def attn_fwd( ...@@ -777,186 +480,161 @@ def attn_fwd(
seqlen_q = MAX_SEQLENS_Q seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K seqlen_k = MAX_SEQLENS_K
if continue_condition: # Now we compute whether we need to exit early due to causal masking.
# Now we compute whether we need to exit early due to causal # This is because for seqlen_q > seqlen_k, M rows of the attn scores
# masking. This is because for seqlen_q > seqlen_k, M rows of the # are completely masked, resulting in 0s written to the output, and
# attn scores are completely masked, resulting in 0s written to the # inf written to LSE. We don't need to do any GEMMs in this case.
# output, and inf written to LSE. We don't need to do any GEMMs in # This block of code determines what N is, and if this WG is operating
# this case. This block of code determines what N is, and if this # on those M rows.
# WG is operating on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N) n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if (IS_CAUSAL): if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which # If seqlen_q != seqlen_k, attn scores are rectangular which means
# means the causal mask boundary is bottom right aligned, and # the causal mask boundary is bottom right aligned, and ends at either
# ends at either the top edge (seqlen_q < seqlen_k) or left # the top edge (seqlen_q < seqlen_k) or left edge.
# edge. This captures the decrease in n_blocks if we have a # This captures the decrease in n_blocks if we have a rectangular attn
# rectangular attn matrix # matrix
n_blocks_seqlen = cdiv_fn( n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
# This is what adjusts the block_max for the current WG, only # This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
# n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen) n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this # If we have no blocks after adjusting for seqlen deltas, this WG is
# WG is part of the blocks that are all 0. We exit early. # part of the blocks that are all 0. We exit early.
if n_blocks <= 0: if n_blocks <= 0:
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
cu_seqlens_q_start * stride_om) off_h_q * stride_oh)
o_ptrs = (o_offset + offs_m[:, None] * stride_om + O_block_ptr = tl.make_block_ptr(
offs_d[None, :] * stride_on) base=Out + o_offset,
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) shape=(seqlen_q, BLOCK_DMODEL),
o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( strides=(stride_om, stride_on),
[BLOCK_M, BLOCK_DMODEL]) offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result # We still need to write 0s to the result
tl.store(o_ptrs, acc, mask=o_ptrs_mask) # tl.store(O_block_ptr,
# The tensor allocated for L is based on MAX_SEQLENS_Q as # acc.to(Out.type.element_ty), boundary_check=(0,1))
# that is statically known. # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + # + offs_m
off_h_q * MAX_SEQLENS_Q + offs_m)
# We store inf to LSE, not -inf because in the bwd pass, # We store inf to LSE, not -inf because in the bwd pass,
# we subtract this from qk which makes it -inf, such that # we subtract this
# exp(qk - inf) = 0 for these masked blocks. # from qk which makes it -inf, such that exp(qk - inf) = 0
l_value = tl.full([BLOCK_M], # for these masked blocks.
value=float("inf"), # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
dtype=tl.float32) # tl.store(l_ptrs, l)
l_ptrs_mask = offs_m < MAX_SEQLENS_Q # TODO: Should dropout and return encoded softmax be handled here?
tl.store(l_ptrs, l_value, mask=l_ptrs_mask) return
# TODO: Should dropout and return encoded softmax be
# handled here too?
continue_condition = False
# return
if continue_condition:
# If MQA / GQA, set the K and V head offsets appropriately. # If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0 n_extra_tokens = 0
if seqlen_k < BLOCK_N: if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N: elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N n_extra_tokens = seqlen_k % BLOCK_N
USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
!= BLOCK_DMODEL)
# Compute pointers for all the tensors used in this kernel. # Compute pointers for all the tensors used in this kernel.
q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + q_offset = (off_z * stride_qz + off_h_q * stride_qh +
cu_seqlens_q_start * stride_qm) cu_seqlens_q_start * stride_qm)
q_ptrs = (q_offset + offs_m[:, None] * stride_qm + Q_block_ptr = tl.make_block_ptr(
offs_d[None, :] * stride_qk) base=Q + q_offset,
k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
cu_seqlens_k_start * stride_kn) cu_seqlens_k_start * stride_kn)
k_ptrs = (k_offset + offs_d[:, None] * stride_kk + K_block_ptr = tl.make_block_ptr(
offs_n[None, :] * stride_kn) base=K + k_offset,
v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
cu_seqlens_k_start * stride_vk) cu_seqlens_k_start * stride_vk)
v_ptrs = (v_offset + offs_n[:, None] * stride_vk + V_block_ptr = tl.make_block_ptr(
offs_d[None, :] * stride_vn) base=V + v_offset,
# Compute pointers for all scale tensors used in this kernel. shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( offsets=(0, 0),
not IS_EIGHT_BIT_KV) block_shape=(BLOCK_N, BLOCK_DMODEL),
if IS_EIGHT_BIT: order=(1, 0),
if k_descale_has_singleton: )
k_descale_ptrs = k_descale_ptr if BIAS_TYPE != 0:
else: bias_ptr = tl.make_block_ptr(
k_descale_ptrs = k_descale_ptr + off_h_k base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
if v_descale_has_singleton: strides=(stride_bm, stride_bn),
v_descale_ptrs = v_descale_ptr offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else: else:
v_descale_ptrs = v_descale_ptr + off_h_k bias_ptr = None
if ENABLE_DROPOUT:
if not IS_EIGHT_BIT_KV: batch_philox_offset = philox_offset_base \
if q_descale_has_singleton: + (off_z * HQ + off_h_q) \
q_descale_ptrs = q_descale_ptr * seqlen_q * seqlen_k
else:
q_descale_ptrs = q_descale_ptr + off_h_q
if USE_P_SCALE:
if p_descale_has_singleton:
p_scale_ptrs = p_scale_ptr
p_descale_ptrs = p_descale_ptr
else:
p_scale_ptrs = p_scale_ptr + off_h_q
p_descale_ptrs = p_descale_ptr + off_h_q
if USE_BIAS:
bias_offset = off_h_q * stride_bh
bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm +
offs_n[None, :] * stride_bn)
else: else:
bias_ptrs = None
if USE_ALIBI:
a_offset = off_z * stride_az + off_h_q * stride_ah
alibi_slope = tl.load(alibi_slopes + a_offset)
else:
alibi_slope = None
batch_philox_offset = 0 batch_philox_offset = 0
# We can ask to return the dropout mask without doing any # We can ask to return the dropout mask without actually doing any dropout.
# dropout. In this case, we return an invalid pointer so # In this case, we return an invalid pointer so indicate the mask is not i
# indicate the mask is not valid. # valid.
if SHOULD_RETURN_ENCODED_SOFTMAX: # TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
encoded_sm_base = (encoded_softmax + if RETURN_ENCODED_SOFTMAX:
off_h_q * seqlen_q * seqlen_k) encoded_softmax_block_ptr = tl.make_block_ptr(
encoded_sm_ptrs = (encoded_sm_base + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
offs_m[:, None] * seqlen_k + shape=(seqlen_q, seqlen_k),
offs_n[None, :]) strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else: else:
encoded_sm_ptrs = None encoded_softmax_block_ptr = 0
# initialize pointer to m and l # initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do # scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# not have native e^x support in HW. # have native e^x support in HW.
QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks. # Q is loaded once at the beginning and shared by all N blocks.
q_ptrs_mask = offs_m[:, None] < seqlen_q q = load_fn(Q_block_ptr, True, padded_head, "zero")
if USE_PADDED_HEAD: if not USE_FP8:
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
< IS_ACTUAL_BLOCK_DMODEL) acc_scale = 1.0
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
if IS_EIGHT_BIT:
k_descale = tl.load(k_descale_ptrs)
v_descale = tl.load(v_descale_ptrs)
q_descale = None if IS_EIGHT_BIT_KV else tl.load(
q_descale_ptrs)
if USE_P_SCALE:
p_scale = tl.load(p_scale_ptrs)
p_descale = tl.load(p_descale_ptrs)
else: else:
p_scale = None qk_scale *= q_scale * k_scale
p_descale = None acc_scale = p_scale * v_scale
else:
q_descale = None
k_descale = None
v_descale = None
p_scale = None
p_descale = None
# Here we compute how many full and masked blocks we have. # Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0 padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL: if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked # There are always at least BLOCK_M // BLOCK_N masked blocks.
# blocks. Additionally there might be one more due to # Additionally there might be one more due to dissimilar seqlens.
# dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else: else:
# Padding on Q does not need to be masked in the FA loop. # Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an # if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# additional block. In this case we might exceed n_blocks so # block. In this case we might exceed n_blocks so pick the min.
# pick the min.
masked_blocks = min(masked_blocks, n_blocks) masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks n_full_blocks = n_blocks - masked_blocks
block_min = 0 block_min = 0
block_max = n_blocks * BLOCK_N block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false # Compute for full blocks. Here we set causal to false regardless of its
# regardless of its actual value because there is no masking. # value because there is no masking. Similarly we do not need padding.
# Similarly we do not need padding.
if n_full_blocks > 0: if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i = _attn_fwd_inner(
...@@ -964,29 +642,21 @@ def attn_fwd( ...@@ -964,29 +642,21 @@ def attn_fwd(
l_i, l_i,
m_i, m_i,
q, q,
k_ptrs, K_block_ptr,
v_ptrs, V_block_ptr,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
seqlen_k, seqlen_k,
seqlen_q, dropout_p,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_sm_ptrs, encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_min,
block_max, block_max,
0, 0,
0, 0,
0, 0,
alibi_slope, bias_ptr,
q_descale,
k_descale,
v_descale,
p_scale,
# IS_CAUSAL, .... # IS_CAUSAL, ....
False, False,
BLOCK_M, BLOCK_M,
...@@ -994,105 +664,85 @@ def attn_fwd( ...@@ -994,105 +664,85 @@ def attn_fwd(
BLOCK_N, BLOCK_N,
offs_m, offs_m,
offs_n, offs_n,
# _, SHOULD_MASK_STEPS, ... # _, MASK_STEPS, ...
SHOULD_PRE_LOAD_V, PRE_LOAD_V,
False, False,
SHOULD_RETURN_ENCODED_SOFTMAX, ENABLE_DROPOUT,
USE_PADDED_HEAD, RETURN_ENCODED_SOFTMAX,
IS_ACTUAL_BLOCK_DMODEL, padded_head,
QK_SCALE, USE_FP8,
IS_EIGHT_BIT_GEMM, qk_scale,
USE_P_SCALE, p_descale,
IS_EIGHT_BIT_KV, )
QUANT_DTYPE)
block_min = block_max block_min = block_max
block_max = n_blocks * BLOCK_N block_max = n_blocks * BLOCK_N
tl.debug_barrier() tl.debug_barrier()
# Remaining blocks, if any, are full / not masked. # Remaining blocks, if any, are full / not masked.
if (masked_blocks > 0): if masked_blocks > 0:
if IS_CAUSAL: offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
offs_n_causal = offs_n + (seqlen_q - seqlen_k) K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
else: V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
offs_n_causal = 0 if bias_ptr is not None:
k_ptrs += n_full_blocks * BLOCK_N * stride_kn bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
v_ptrs += n_full_blocks * BLOCK_N * stride_vk if RETURN_ENCODED_SOFTMAX:
if USE_BIAS: encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
bias_ptrs += n_full_blocks * BLOCK_N * stride_bn (0, n_full_blocks))
if SHOULD_RETURN_ENCODED_SOFTMAX:
encoded_sm_ptrs += n_full_blocks * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner( acc, l_i, m_i = _attn_fwd_inner(
acc, acc,
l_i, l_i,
m_i, m_i,
q, q,
k_ptrs, K_block_ptr,
v_ptrs, V_block_ptr,
bias_ptrs,
stride_kn,
stride_vk,
stride_bn,
start_m, start_m,
seqlen_k, seqlen_k,
seqlen_q, dropout_p,
philox_seed, philox_seed,
batch_philox_offset, batch_philox_offset,
encoded_sm_ptrs, encoded_softmax_block_ptr,
block_min, block_min,
block_max, block_max,
offs_n_causal, offs_n_causal,
masked_blocks, masked_blocks,
n_extra_tokens, n_extra_tokens,
alibi_slope, bias_ptr,
q_descale,
k_descale,
v_descale,
p_scale,
IS_CAUSAL, IS_CAUSAL,
BLOCK_M, BLOCK_M,
BLOCK_DMODEL, BLOCK_DMODEL,
BLOCK_N, BLOCK_N,
offs_m, offs_m,
offs_n, offs_n,
# _, SHOULD_MASK_STEPS, ... # _, MASK_STEPS, ...
SHOULD_PRE_LOAD_V, PRE_LOAD_V,
True, True,
SHOULD_RETURN_ENCODED_SOFTMAX, ENABLE_DROPOUT,
USE_PADDED_HEAD, RETURN_ENCODED_SOFTMAX,
IS_ACTUAL_BLOCK_DMODEL, padded_head,
QK_SCALE, USE_FP8,
IS_EIGHT_BIT_GEMM, qk_scale,
USE_P_SCALE, p_descale,
IS_EIGHT_BIT_KV, )
QUANT_DTYPE)
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV:
if USE_P_SCALE:
acc *= p_descale
acc *= v_descale
# epilogue # epilogue
# This helps the compiler do Newton Raphson on l_i vs on acc
# which is much larger. if USE_FP8:
l_recip = 1 / l_i[:, None] acc *= acc_scale
acc = acc * l_recip acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
# If seqlen_q > seqlen_k but the delta is not a multiple of acc = acc / (1 - dropout_p)
# BLOCK_M, then we have one block with a row of all NaNs which # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# come from computing softmax over a row of all # then we have one block with a row of all NaNs which come from computing
# -infs (-inf - inf = NaN). We check for that here and store 0s # softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# where there are NaNs as these rows should've been zeroed out. # and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k causal_start_idx = seqlen_q - seqlen_k
if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 if USE_FP8_OUT:
if o_descale_ptr is not None: acc *= o_descale
acc = quant_fp8(acc, o_descale) acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
acc = acc.to(Out.type.element_ty) acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102 if IS_CAUSAL: # noqa: SIM102
if (causal_start_idx > start_m_idx if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
and causal_start_idx < end_m_idx):
out_mask_boundary = tl.full((BLOCK_DMODEL, ), out_mask_boundary = tl.full((BLOCK_DMODEL, ),
causal_start_idx, causal_start_idx,
dtype=tl.int32) dtype=tl.int32)
...@@ -1100,275 +750,229 @@ def attn_fwd( ...@@ -1100,275 +750,229 @@ def attn_fwd(
out_ptrs_mask = (mask_m_offsets[:, None] out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :]) >= out_mask_boundary[None, :])
z = tl.zeros((1, ), tl.float32) z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc, acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
z.to(acc.type.element_ty))
# write back LSE # write back LSE
l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
off_h_q * MAX_SEQLENS_Q + offs_m) # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# If seqlen_q not multiple of BLOCK_M, we need to mask out the # few rows. This is only true for the last M block. For others,
# last few rows. This is only true for the last M block. # overflow_size will be -ve
# For others, overflow_size will be -ve # overflow_size = end_m_idx - seqlen_q
overflow_size = end_m_idx - seqlen_q # if overflow_size > 0:
if overflow_size > 0: # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
boundary = tl.full((BLOCK_M, ), # # This is a > check because mask being 0 blocks the store.
BLOCK_M - overflow_size, # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
dtype=tl.int32) # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary # else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) # tl.store(l_ptrs, m_i + tl.math.log2(l_i))
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O # write back O
o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
cu_seqlens_q_start * stride_om) off_h_q * stride_oh)
o_ptrs = (o_offset + offs_m[:, None] * stride_om + O_block_ptr = tl.make_block_ptr(
offs_d[None, :] * stride_on) base=Out + o_offset,
o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
if overflow_size > 0: strides=(stride_om, stride_on),
o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) offsets=(start_m * BLOCK_M, 0),
if USE_PADDED_HEAD: block_shape=(BLOCK_M, BLOCK_DMODEL),
o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] order=(1, 0),
< IS_ACTUAL_BLOCK_DMODEL) )
tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) # Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
def get_shape_from_layout(q, k, metadata): tl.store(O_block_ptr, acc, boundary_check=(0, 1))
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout."
if metadata.layout == 'thd': def check_args(
nheads_q, nheads_k = q.shape[1], k.shape[1] q,
head_size = q.shape[-1] k,
batch = metadata.num_contexts v,
elif metadata.layout == 'bhsd': o,
batch, nheads_q, _, head_size = q.shape varlen=True,
nheads_k = k.shape[1] max_seqlens=None,
elif metadata.layout == 'bshd': cu_seqlens_q=None,
batch, _, nheads_q, head_size = q.shape cu_seqlens_k=None,
nheads_k = k.shape[2] ):
return batch, nheads_q, nheads_k, head_size assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
def get_strides_from_layout(q, k, v, o, metadata): total_q, nheads_q, head_size = q.shape
assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
STRIDE_PERMUTATIONS = { assert cu_seqlens_k is not None
'thd': (None, 1, 0, 2), assert len(cu_seqlens_q) == len(cu_seqlens_k)
'bhsd': (0, 1, 2, 3), else:
'bshd': (0, 2, 1, 3), assert q.dim() == 4
} batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
perm = STRIDE_PERMUTATIONS[metadata.layout] assert max_seqlens > 0
stride = lambda x, p: (0 if p is None else x.stride(p)) assert k.shape == v.shape
strides = lambda x: (stride(x, p) for p in perm) assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
return tuple(strides(x) for x in [q, k, v, o]) assert q.dtype == k.dtype and q.dtype == v.dtype
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function): class _attention(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, q, k, v, o, metadata: MetaData): def forward(
# NOTE: a large bias tensor leads to overflow during pointer arithmetic ctx,
if (metadata.bias is not None):
assert (metadata.bias.numel() < 2**31)
if o is None:
if metadata.eight_bit:
o = torch.empty_like(
q, q,
dtype=metadata.output_dtype if metadata.output_dtype k,
is not None else metadata.eight_bit_dtype_torch) v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
fp8_scales=None,
fp8_out_scale=None,
):
if fp8_scales is not None:
use_fp8 = True
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
float8 = current_platform.fp8_dtype()
def check_and_convert(t, scale):
if t.dtype != float8:
descale = 1.0 / scale
ts = (t * descale).clamp(min=float8_info.min,
max=float8_info.max)
return ts.to(float8)
else: else:
o = torch.empty_like(q, dtype=q.dtype) return t
metadata.check_args(q, k, v, o) q = check_and_convert(q, q_scale)
k = check_and_convert(k, k_scale)
v = check_and_convert(v, v_scale)
else:
use_fp8 = False
q_scale = k_scale = v_scale = p_scale = 1.0
batch, nheads_q, nheads_k, head_size = get_shape_from_layout( if o is None:
q, k, metadata) o = torch.empty_like(q, dtype=v.dtype)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(
q, k, v, o, metadata) check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. # Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length() unpadded_head_dims = {32, 64, 128, 256}
# Smallest head_dim supported is 16. If smaller, the tile in the if head_size not in unpadded_head_dims:
# kernel is padded - there is no padding in memory for any dims. padded_d_model = None
padded_d_model = max(padded_d_model, 16) for i in unpadded_head_dims:
if i > head_size:
# encoded_softmax is used to validate dropout behavior vs the padded_d_model = i
# PyTorch SDPA math backend reference. We zero this out to give a break
# consistent starting point and then populate it with the output of assert padded_d_model is not None
# softmax with the sign bit set according to the dropout mask.
# The resulting return allows this mask to be fed into the reference
# implementation for testing only. This return holds no useful output
# aside from debugging.
if metadata.return_encoded_softmax:
encoded_softmax = torch.zeros(
(q.shape[0], q.shape[1], q.shape[2], k.shape[2]),
device=q.device,
dtype=torch.float32)
else: else:
encoded_softmax = None padded_d_model = head_size
M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), grid = lambda META: (
device=q.device, triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
dtype=torch.float32) nheads_q,
batch,
)
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing. # Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52 philox_seed = 0x1BF52
philox_offset = 0x1D4B42 philox_offset = 0x1D4B42
if metadata.bias is not None: if bias is not None:
bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), bias_strides = (
metadata.bias.stride(2), metadata.bias.stride(3)) bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else: else:
bias_strides = (0, 0, 0, 0) bias_strides = (0, 0, 0, 0)
if metadata.alibi_slopes is not None: p_descale = 1.0 / p_scale
alibi_strides = (metadata.alibi_slopes.stride(0), o_descale = 1.0 / fp8_out_scale.item(
metadata.alibi_slopes.stride(1)) ) if fp8_out_scale is not None else 1.0
else:
alibi_strides = (0, 0)
if metadata.eight_bit: arg_max_seqlens_q = 0 if on_gfx1x() else max_seqlens_q
q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k
metadata.q_descale, metadata.k_descale, metadata.p_scale,
metadata.p_descale, metadata.v_descale, metadata.o_scale)
o_descale = 1.0 / o_scale if o_scale is not None else None
else:
q_descale = k_descale = p_scale = None
p_descale = v_descale = o_descale = None
# number of compute units available
NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[
'BLOCK_M']), nheads_q, batch)
attn_fwd[grid]( attn_fwd[grid](
q, q,
k, k,
v, v,
metadata.bias, bias,
metadata.sm_scale, sm_scale,
M, q_scale,
k_scale,
v_scale,
p_scale,
p_descale,
o_descale,
None,
o, o,
*q_strides, *q_strides,
*k_strides, *k_strides,
*v_strides, *v_strides,
*o_strides, *o_strides,
*bias_strides, *bias_strides,
*alibi_strides, cu_seqlens_q,
q_descale, cu_seqlens_k,
k_descale, dropout_p=0.0,
p_scale,
p_descale,
o_descale,
v_descale,
q_descale.numel() == 1 if q_descale is not None else False,
k_descale.numel() == 1 if k_descale is not None else False,
p_descale.numel() == 1 if p_descale is not None else False,
v_descale.numel() == 1 if v_descale is not None else False,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
philox_seed=philox_seed, philox_seed=philox_seed,
philox_offset_base=philox_offset, philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HQ=nheads_q,
HK=nheads_k, HK=nheads_k,
IS_ACTUAL_BLOCK_DMODEL=head_size, ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_Q=arg_max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k, MAX_SEQLENS_K=arg_max_seqlens_k,
IS_CAUSAL=metadata.causal, IS_CAUSAL=causal,
VARLEN=metadata.varlen, VARLEN=True,
BLOCK_DMODEL=padded_d_model, BLOCK_DMODEL=padded_d_model,
USE_BIAS=metadata.bias is not None, BIAS_TYPE=0 if bias is None else 1,
USE_ALIBI=metadata.alibi_slopes is not None, ENABLE_DROPOUT=False,
SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, RETURN_ENCODED_SOFTMAX=False,
IS_EIGHT_BIT=metadata.eight_bit, USE_FP8=use_fp8,
USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, USE_FP8_OUT=fp8_out_scale is not None,
IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, )
NUM_CU=NUM_CU,
B=batch,
QUANT_DTYPE=metadata.eight_bit_dtype_triton)
ctx.grid = grid ctx.grid = grid
ctx.sm_scale = metadata.sm_scale ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size ctx.BLOCK_DMODEL = head_size
ctx.causal = metadata.causal ctx.causal = causal
ctx.alibi_slopes = metadata.alibi_slopes ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = metadata.return_encoded_softmax ctx.return_encoded_softmax = False
return o, encoded_softmax return o, encoded_softmax
triton_attention_rocm = _attention.apply triton_attention = _attention.apply
def scale_fp8(t, scale=None):
t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]),
scale)
return t_scaled.reshape(t.shape), scale_out
def maybe_quantize_fp8(t, scale):
eight_bit_dtype = current_platform.fp8_dtype()
if t.dtype != eight_bit_dtype:
t, _ = scale_fp8(t, scale)
return t
def check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
(q_scale, k_scale, v_scale, p_scale) = fp8_scales
q = maybe_quantize_fp8(q, q_scale)
k = maybe_quantize_fp8(k, k_scale)
v = maybe_quantize_fp8(v, v_scale)
return q, k, v
# query - [num_tokens, num_heads, head_size]
# key - [num_tokens, num_kv_heads, head_size]
# value - [num_tokens, num_kv_heads, head_size
# output - [num_tokens, num_heads, head_size]
def triton_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlens_q: int,
max_seqlens_k: int,
causal: bool = False,
sm_scale: float = 1.0,
bias: Optional[torch.Tensor] = None,
fp8_scales: Optional[tuple[float, ...]] = None,
input_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if fp8_scales is not None:
q_descale, k_descale, v_descale, p_scale = fp8_scales
else:
q_descale = k_descale = v_descale = p_scale = None
attn_metadata = MetaData(sm_scale=sm_scale,
max_seqlens_q=max_seqlens_q,
max_seqlens_k=max_seqlens_k,
causal=causal,
bias=bias,
output_dtype=q.dtype,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
p_scale=p_scale,
o_scale=input_scale)
if fp8_scales is not None:
q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)
return triton_attention_rocm(q, k, v, o, attn_metadata)
...@@ -98,6 +98,12 @@ def with_amdsmi_context(fn): ...@@ -98,6 +98,12 @@ def with_amdsmi_context(fn):
return wrapper return wrapper
@cache
def on_gfx1x() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@cache @cache
def on_mi250_mi300() -> bool: def on_mi250_mi300() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment