Commit 86862cfd authored by Tri Dao

Implement attention bias for Triton version

parent 470010f5
""" """
*Experimental* implementation of FlashAttention in Triton.
We use the FlashAttention implementation from Phil Tillet a starting point. We use the FlashAttention implementation from Phil Tillet a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
...@@ -7,6 +9,7 @@ Changes: ...@@ -7,6 +9,7 @@ Changes:
- Implement both self-attention and cross-attention. - Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
- Support attention bias.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
@@ -31,6 +34,8 @@ import math
import torch
from einops import rearrange, repeat
import triton
import triton.language as tl
@@ -41,7 +46,7 @@ import triton.language as tl
# This config has a race condition when EVEN_M == False, disabling it for now.
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@triton.heuristics(
{
@@ -52,15 +57,17 @@ import triton.language as tl
)
@triton.jit
def _fwd_kernel(
Q, K, V, Bias, Out,
Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
@@ -84,6 +91,10 @@ def _fwd_kernel(
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
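# A 'vector' bias holds one value per key position and is broadcast over queries, while a
# 'matrix' bias holds one value per (query, key) pair; only the matrix layout needs
# stride_bm to step along the query dimension.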
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -123,13 +134,34 @@ def _fwd_kernel(
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
# Trying to combine the two masks seems to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
if BIAS_TYPE != 'none':
if BIAS_TYPE == 'vector':
if EVEN_N:
bias = tl.load(b_ptrs + start_n).to(tl.float32)
else:
bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
bias = bias[None, :]
elif BIAS_TYPE == 'matrix':
if EVEN_M & EVEN_N:
bias = tl.load(b_ptrs + start_n).to(tl.float32)
else:
bias = tl.load(b_ptrs + start_n,
mask=(offs_m[:, None] < seqlen_q)
& ((start_n + offs_n)[None, :] < seqlen_k),
other=0.0).to(tl.float32)
# Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
# can then fuse the mult and add into an fma instruction. But if we have bias we need to
# multiply with softmax_scale here.
qk = qk * softmax_scale + bias
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
p = tl.exp(qk - m_ij[:, None])
else:
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
# Slightly faster to multiply the softmax_scale here since the compiler can then
# fuse the mult and add into an fma instruction.
p = tl.exp(qk * softmax_scale - m_ij[:, None])
l_ij = tl.sum(p, 1)
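# m_ij is the new running max (folding in the running LSE for numerical stability), p the
# unnormalized probabilities for this block, and l_ij their row sums; these feed the
# streaming softmax update in the rest of the loop (elided by the diff).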
@@ -218,12 +250,15 @@ def _bwd_preprocess_do_o_dot(
@triton.jit
def _bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD: tl.constexpr,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
@@ -242,6 +277,10 @@ def _bwd_kernel_one_col_block(
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
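# Bias has already been offset to this (batch, head) slice by the caller (_bwd_kernel),
# so only the in-slice offsets are needed here.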
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -286,12 +325,31 @@ def _bwd_kernel_one_col_block(
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
if BIAS_TYPE != 'none':
if BIAS_TYPE == 'vector':
if EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
bias = bias[None, :]
elif BIAS_TYPE == 'matrix':
if EVEN_M & EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)
else:
bias = tl.load(b_ptrs,
mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_n[None, :] < seqlen_k),
other=0.0).to(tl.float32)
qk = qk * softmax_scale + bias
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
if BIAS_TYPE == 'none':
p = tl.exp(qk * softmax_scale - lse_i[:, None])
else:
p = tl.exp(qk - lse_i[:, None])
# compute dv
# [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
@@ -368,6 +426,8 @@ def _bwd_kernel_one_col_block(
dq_ptrs += BLOCK_M * stride_dqm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_dom
if BIAS_TYPE == 'matrix':
b_ptrs += BLOCK_M * stride_bm
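# Only the matrix bias advances with the query block; the vector bias does not depend on
# the query index, so its pointer stays fixed across iterations.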
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
@@ -392,6 +452,7 @@ def _bwd_kernel_one_col_block(
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
@@ -403,7 +464,7 @@ def init_to_zero(name):
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
)
@triton.heuristics(
{
@@ -414,19 +475,21 @@ def init_to_zero(name):
)
@triton.jit
def _bwd_kernel(
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_dob, stride_doh, stride_dom,
stride_dqb, stride_dqh, stride_dqm,
stride_dkb, stride_dkh, stride_dkn,
stride_dvb, stride_dvh, stride_dvn,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
@@ -444,6 +507,8 @@ def _bwd_kernel(
DQ += off_b * stride_dqb + off_h * stride_dqh
DK += off_b * stride_dkb + off_h * stride_dkh
DV += off_b * stride_dvb + off_h * stride_dvh
if BIAS_TYPE != 'none':
Bias += off_b * stride_bb + off_h * stride_bh
# pointer to row-wise quantities in value-like data
D += off_hb * seqlen_q_rounded
LSE += off_hb * seqlen_q_rounded
@@ -452,12 +517,15 @@ def _bwd_kernel(
for start_n in range(0, num_block_n):
_bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD=False,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
@@ -467,12 +535,15 @@ def _bwd_kernel(
start_n = tl.program_id(0)
_bwd_kernel_one_col_block(
start_n,
Q, K, V, Bias,
DO, DQ, DK, DV,
LSE, D,
softmax_scale,
stride_qm, stride_kn, stride_vn, stride_bm,
stride_dom, stride_dqm, stride_dkn, stride_dvn,
seqlen_q, seqlen_k, headdim,
ATOMIC_ADD=True,
BIAS_TYPE=BIAS_TYPE,
IS_CAUSAL=IS_CAUSAL,
BLOCK_HEADDIM=BLOCK_HEADDIM,
EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
@@ -480,7 +551,7 @@ def _bwd_kernel(
)
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
# shape constraints
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
@@ -491,10 +562,31 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
assert q.is_cuda and k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
has_bias = bias is not None
bias_type = 'none'
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
if bias.stride(-1) != 1:
bias = bias.contiguous()
if bias.shape[2:] == (1, seqlen_k):
bias_type = 'vector'
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = 'matrix'
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastable to (batch, nheads)'
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
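# When there is no bias, the (0, 0, 0) strides are placeholders: with BIAS_TYPE == 'none'
# the kernel never dereferences the Bias pointer.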
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
# lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device,
# dtype=torch.float32)
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
o = torch.empty_like(q)
@@ -503,18 +595,19 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
# num_warps = 4 if d <= 64 else 8
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_fwd_kernel[grid](
q, k, v, bias, o,
lse, tmp,
softmax_scale,
q.stride(0), q.stride(2), q.stride(1),
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
*bias_strides,
o.stride(0), o.stride(2), o.stride(1),
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type, causal, BLOCK_HEADDIM,
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
# num_warps=num_warps,
# num_stages=1,
@@ -522,7 +615,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
return o, lse, softmax_scale # softmax_scale could have been updated
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
# Make sure that the last dimension is contiguous
if do.stride(-1) != 1:
do = do.contiguous()
@@ -532,6 +625,8 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
assert d <= 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32)
@@ -548,19 +643,41 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
)
has_bias = bias is not None
bias_type = 'none'
if has_bias:
assert bias.dtype in [q.dtype, torch.float]
assert bias.is_cuda
assert bias.dim() == 4
assert bias.stride(-1) == 1
if bias.shape[2:] == (1, seqlen_k):
bias_type = 'vector'
elif bias.shape[2:] == (seqlen_q, seqlen_k):
bias_type = 'matrix'
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastable to (batch, nheads)'
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# BLOCK_M = 128
# BLOCK_N = 64
# num_warps = 4
grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
batch * nheads)
_bwd_kernel[grid](
q, k, v, bias,
do, dq_accum, dk, dv,
lse, delta,
softmax_scale,
q.stride(0), q.stride(2), q.stride(1),
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
*bias_strides,
do.stride(0), do.stride(2), do.stride(1),
dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
dk.stride(0), dk.stride(2), dk.stride(1),
@@ -569,7 +686,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type, causal, BLOCK_HEADDIM,
# SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
# num_warps=num_warps,
@@ -581,31 +698,36 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
"""
qkv: (batch, seqlen, 3, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
"""
# Make sure that the last dimension is contiguous
if qkv.stride(-1) != 1:
qkv = qkv.contiguous()
o, lse, ctx.softmax_scale = _flash_attn_forward(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal,
softmax_scale=softmax_scale
)
ctx.save_for_backward(qkv, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
qkv, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dqkv = torch.empty_like(qkv)
_flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dqkv, None, None, None
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
@@ -614,23 +736,27 @@ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
"""
q: (batch, seqlen_q, nheads, headdim)
kv: (batch, seqlen_k, 2, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
"""
# Make sure that the last dimension is contiguous
q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
o, lse, ctx.softmax_scale = _flash_attn_forward(
q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
)
ctx.save_for_backward(q, kv, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, kv, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
@@ -638,8 +764,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
dkv = torch.empty_like(kv)
_flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
dq, dkv[:, :, 0], dkv[:, :, 1],
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dkv, None, None, None
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
@@ -648,21 +774,27 @@ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
"""
q: (batch_size, seqlen_q, nheads, headdim)
k, v: (batch_size, seqlen_k, nheads, headdim)
bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
"""
# Make sure that the last dimension is contiguous
q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
o, lse, ctx.softmax_scale = _flash_attn_forward(
q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
)
ctx.save_for_backward(q, k, v, o, lse, bias)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, bias = ctx.saved_tensors
assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet'
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
@@ -670,8 +802,8 @@ class FlashAttnFunc(torch.autograd.Function):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dk, dv, None, None, None
flash_attn_func = FlashAttnFunc.apply
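# Illustrative usage sketch (not part of this commit; the shapes, dtypes, and ALiBi slopes
# below are made up): build a per-key ALiBi-style bias of shape (1, nheads, 1, seqlen),
# which _flash_attn_forward broadcasts to (batch, nheads, ...) and treats as a 'vector' bias.
def _example_alibi_usage(batch=2, seqlen=1024, nheads=8, d=64, device='cuda'):
    q = torch.randn(batch, seqlen, nheads, d, device=device, dtype=torch.float16)
    k, v = torch.randn(batch, seqlen, 2, nheads, d, device=device, dtype=torch.float16).unbind(dim=2)
    # One slope per head, following the usual geometric ALiBi schedule.
    slopes = torch.tensor([2 ** (-8 * (i + 1) / nheads) for i in range(nheads)], device=device)
    # Penalize keys by their distance from the last position; under a causal mask this is
    # softmax-equivalent to the usual -slope * (i - j) ALiBi bias.
    distance = torch.arange(seqlen - 1, -1, -1, device=device, dtype=torch.float32)
    alibi = -slopes.view(1, nheads, 1, 1) * distance.view(1, 1, 1, seqlen)
    # autograd.Function.apply does not accept keyword arguments, so bias and causal are positional.
    return flash_attn_func(q, k, v, alibi, True)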
@@ -122,7 +122,7 @@ def generate_qkv(x, Wqkv, nheads, query_padding_mask=None, key_padding_mask=None
def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropout_p=0.0,
dropout_mask=None, causal=False, bias=None, upcast=True, reorder_ops=False):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
@@ -132,6 +132,7 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
bias: (batch_size, nheads, seqlen_q, seqlen_k)
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
@@ -150,6 +151,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
scores = torch.einsum('bthd,bshd->bhts', q / math.sqrt(d), k)
else:
scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
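# The reference adds bias after the 1/sqrt(d) scaling, matching the kernel's qk * softmax_scale + bias.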
if bias is not None:
scores = (scores + bias).to(dtype=scores.dtype)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, 'b s -> b 1 1 s'), float('-inf'))
if causal:
@@ -863,11 +866,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['1h1k']))
def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
@@ -877,12 +882,23 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
if bias_shape == '1h1k':
bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == '1hqk':
bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b11k':
bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b1qk':
bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
else:
bias = None
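# After broadcasting to (batch, nheads, ...), '1h1k' and 'b11k' exercise the 'vector' bias
# path and '1hqk' and 'b1qk' the 'matrix' path; None runs the unbiased kernel.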
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output = flash_attn_func(q, k, v, bias, causal)
output_ref, attn_ref = attention_ref(q, k, v, bias=bias, causal=causal)
output_pt, attn_pt = attention_ref(q, k, v, bias=bias, causal=causal, upcast=False,
reorder_ops=True)
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
@@ -919,13 +935,14 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [96])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
@@ -935,19 +952,31 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
if bias_shape == '1h1k':
bias = torch.randn(1, nheads, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == '1hqk':
bias = torch.randn(1, nheads, seqlen_q, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b11k':
bias = torch.randn(batch_size, 1, 1, seqlen_k, dtype=torch.float, device=device)
elif bias_shape == 'b1qk':
bias = torch.randn(batch_size, 1, seqlen_q, seqlen_k, dtype=torch.float, device=device)
else:
bias = None
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output_0 = flash_attn_func(q, k, v, bias, causal)
g = torch.randn_like(output_0)
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
# The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
deterministic_dq = False
# Numerical error if we just do any arithmetic on dq
dq_atol = ((dq_0 + 0.3 - 0.3) - dq_0).abs().max().item()
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
# Run 10000 times and check that the results don't change
for i in range(10000):
output = flash_attn_func(q, k, v, bias, causal)
output_equal = torch.equal(output, output_0)
if not output_equal: # Printing / computing diff sometimes makes the race condition disappear
print(f'Output max diff: {(output - output_0).abs().max().item()}')
...