Commit e78d509c authored by Tri Dao

[WIP] Support all head dimensions up to 128 in the Triton bwd

WIP because there seem to be some race conditions for head dimensions other than 16, 32, 64, and 128.
parent 008951f1
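
The overall approach in the diff below is to pad the head dimension to the next power of two for the kernel block (BLOCK_HEADDIM) and, whenever the real head dimension is smaller, to mask every load and store along that axis with offs_d < headdim. A minimal Python sketch of that padding logic follows; the standalone next_power_of_2 helper stands in for triton.next_power_of_2 and is not part of this commit.

def next_power_of_2(n: int) -> int:
    # Stand-in for triton.next_power_of_2, so the sketch runs without Triton installed.
    return 1 << (n - 1).bit_length()

def headdim_config(d: int):
    block_headdim = max(next_power_of_2(d), 16)  # BLOCK_HEADDIM used by the kernels
    even_headdim = (d == block_headdim)          # the EVEN_HEADDIM heuristic in the diff
    return block_headdim, even_headdim

for d in [16, 32, 40, 48, 64, 80, 88, 96, 128]:
    print(d, headdim_config(d))
# d=48 -> (64, False): every load/store must be masked with offs_d < headdim;
# d=64 -> (64, True): the unmasked fast path is taken.
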
@@ -6,7 +6,9 @@ Changes:
 - Implement both causal and non-causal attention.
 - Implement cross-attention (not just self-attention).
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- Support all head dimensions up to 128 (not just 16, 32, 64, 128), in the forward pass.
+- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
+  and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
+  more testing since there seem to be race conditions due to the Triton compiler.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
@@ -175,18 +177,12 @@ def _fwd_kernel(
                  mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
 
 
-@triton.heuristics(
-    {
-        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
-    }
-)
 @triton.jit
 def _bwd_preprocess_do_o_dot(
     Out, DO, Delta,
     stride_ob, stride_oh, stride_om,
     stride_dob, stride_doh, stride_dom,
-    nheads, seqlen_q, seqlen_q_rounded,
-    EVEN_M: tl.constexpr,
+    nheads, seqlen_q, seqlen_q_rounded, headdim,
     BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
 ):
     start_m = tl.program_id(0)
@@ -197,14 +193,10 @@ def _bwd_preprocess_do_o_dot(
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
     # load
-    if EVEN_M:
-        o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]).to(tl.float32)
-        do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :]).to(tl.float32)
-    else:
-        o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
-                    mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
-        do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
-                     mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
+    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
+                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
+    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
+                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
     delta = tl.sum(o * do, axis=1)
     # write-back
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
@@ -217,11 +209,11 @@ def _bwd_kernel_one_col_block(
     DO, DQ, DK, DV,
     LSE, D,
     stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
-    seqlen_q, seqlen_k,
+    seqlen_q, seqlen_k, headdim,
     ATOMIC_ADD: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
@@ -230,13 +222,13 @@ def _bwd_kernel_one_col_block(
     offs_qm = begin_m + tl.arange(0, BLOCK_M)
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
     offs_m = tl.arange(0, BLOCK_M)
-    offs_k = tl.arange(0, BLOCK_HEADDIM)
+    offs_d = tl.arange(0, BLOCK_HEADDIM)
     # initialize pointers to value-like data
-    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :])
-    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :])
-    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :])
-    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_k[None, :])
-    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_k[None, :])
+    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
+    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
+    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
+    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
+    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -244,11 +236,21 @@ def _bwd_kernel_one_col_block(
     # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
     # if we just call tl.load(k_ptrs), we get the wrong output!
     if EVEN_N & EVEN_M:
+        if EVEN_HEADDIM:
             k = tl.load(k_ptrs)
             v = tl.load(v_ptrs)
+        else:
+            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
+            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
     else:
+        if EVEN_HEADDIM:
             k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
             v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
+        else:
+            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                        other=0.0)
+            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                        other=0.0)
     # loop over rows
     num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
     for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
@@ -256,24 +258,50 @@ def _bwd_kernel_one_col_block(
         offs_m_curr = start_m + offs_m
         # load q, k, v, do on-chip
         if EVEN_M:
+            if EVEN_HEADDIM:
                 q = tl.load(q_ptrs)
             else:
+                q = tl.load(q_ptrs, mask=(offs_d[None, :] < headdim))
+        else:
+            if EVEN_HEADDIM:
                 q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+            else:
+                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+                                         & (offs_d[None, :] < headdim), other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
+        if not EVEN_HEADDIM:
+            tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
         # compute dv
-        if EVEN_M:
+        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
+        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
+        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
+        # the output is correct.
+        if EVEN_M & EVEN_HEADDIM:
             do = tl.load(do_ptrs)
+        # if EVEN_M:
+        #     if EVEN_HEADDIM:
+        #         do = tl.load(do_ptrs)
+        #     else:
+        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
         else:
+            if EVEN_HEADDIM:
                 do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+            else:
+                do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+                                           & (offs_d[None, :] < headdim), other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
+        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
+        if not EVEN_HEADDIM:
+            tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
@@ -286,34 +314,64 @@ def _bwd_kernel_one_col_block(
         # compute dq
         if not ATOMIC_ADD:
             if EVEN_M:
+                if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                     dq += tl.dot(ds, k)
                     tl.store(dq_ptrs, dq, eviction_policy="evict_last")
                 else:
+                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
+                                 eviction_policy="evict_last")
+                    dq += tl.dot(ds, k)
+                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim, eviction_policy="evict_last")
+            else:
+                if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                  eviction_policy="evict_last")
                     dq += tl.dot(ds, k)
                     tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                              eviction_policy="evict_last")
+                else:
+                    dq = tl.load(dq_ptrs,
+                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                                 other=0.0, eviction_policy="evict_last")
+                    dq += tl.dot(ds, k)
+                    tl.store(dq_ptrs, dq,
+                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                             eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
             if EVEN_M:
+                if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq)
                 else:
+                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
+            else:
+                if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
+                else:
+                    tl.atomic_add(dq_ptrs, dq,
+                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
         # increment pointers
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
         do_ptrs += BLOCK_M * stride_dom
     # write-back
-    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :])
-    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :])
+    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
+    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
     if EVEN_N:
+        if EVEN_HEADDIM:
             tl.store(dv_ptrs, dv)
             tl.store(dk_ptrs, dk)
         else:
+            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
+            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
+    else:
+        if EVEN_HEADDIM:
             tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
             tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
+        else:
+            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
 
 
 def init_to_zero(name):
@@ -345,7 +403,8 @@ def init_to_zero(name):
 @triton.heuristics(
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
-        "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
+        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
 @triton.jit
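
For a concrete sense of how these heuristics behave, the snippet below evaluates them for one of the test shapes used further down (seqlen_q=113, seqlen_k=203, d=48). The block sizes here are assumptions for the sketch only; the real values come from the autotuner configs, which are not shown in this hunk.

heuristics = {
    "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
    "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
    "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
# Assumed block sizes for illustration only.
args = dict(seqlen_q=113, seqlen_k=203, headdim=48, BLOCK_M=128, BLOCK_N=128, BLOCK_HEADDIM=64)
print({name: rule(args) for name, rule in heuristics.items()})
# -> {'EVEN_M': False, 'EVEN_N': False, 'EVEN_HEADDIM': False}: every masked branch is exercised.
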
@@ -361,12 +420,12 @@ def _bwd_kernel(
     stride_dqb, stride_dqh, stride_dqm,
     stride_dkb, stride_dkh, stride_dkn,
     stride_dvb, stride_dvh, stride_dvn,
-    nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
+    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
     CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
 ):
     off_hb = tl.program_id(1)
@@ -392,11 +451,11 @@ def _bwd_kernel(
                 DO, DQ, DK, DV,
                 LSE, D,
                 stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
-                seqlen_q, seqlen_k,
+                seqlen_q, seqlen_k, headdim,
                 ATOMIC_ADD=False,
                 IS_CAUSAL=IS_CAUSAL,
                 BLOCK_HEADDIM=BLOCK_HEADDIM,
-                EVEN_M=EVEN_M, EVEN_N=EVEN_N,
+                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                 BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
             )
     else:
@@ -407,11 +466,11 @@ def _bwd_kernel(
             DO, DQ, DK, DV,
             LSE, D,
             stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
-            seqlen_q, seqlen_k,
+            seqlen_q, seqlen_k, headdim,
             ATOMIC_ADD=True,
             IS_CAUSAL=IS_CAUSAL,
             BLOCK_HEADDIM=BLOCK_HEADDIM,
-            EVEN_M=EVEN_M, EVEN_N=EVEN_N,
+            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
         )
@@ -423,7 +482,6 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
     assert k.shape == (batch, seqlen_k, nheads, d)
     assert v.shape == (batch, seqlen_k, nheads, d)
     assert d <= 128, 'FlashAttention only support head dimensions up to 128'
-    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
     assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
     assert q.is_cuda and k.is_cuda and v.is_cuda
@@ -435,6 +493,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
     tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
     o = torch.empty_like(q)
+    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     # BLOCK = 128
     # num_warps = 4 if d <= 64 else 8
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
@@ -464,20 +523,23 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         do = do.contiguous()
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
-    assert d in {16, 32, 64, 128}
+    # assert d in {16, 32, 64, 128}
+    assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
     delta = torch.empty_like(lse)
     # delta = torch.zeros_like(lse)
+    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
     grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
     _bwd_preprocess_do_o_dot[grid](
         o, do, delta,
         o.stride(0), o.stride(2), o.stride(1),
         do.stride(0), do.stride(2), do.stride(1),
-        nheads, seqlen_q, seqlen_q_rounded,
-        BLOCK_M=128, BLOCK_HEADDIM=d,
+        nheads, seqlen_q, seqlen_q_rounded, d,
+        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
     # TODO: There are 2 Memcpy DtoD when I use the autotuner.
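
For reference, the delta that _bwd_preprocess_do_o_dot writes out is the row-wise sum of O * dO over the head dimension, matching delta = tl.sum(o * do, axis=1) in the kernel. A rough PyTorch equivalent is sketched below; it ignores the seqlen_q rounding and boundary masking the kernel handles, and the function name is illustrative, not from the commit.

import torch

def bwd_preprocess_reference(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (batch, seqlen_q, nheads, headdim) -> delta: (batch, nheads, seqlen_q)
    return (o.float() * do.float()).sum(dim=-1).transpose(1, 2)
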
@@ -498,11 +560,11 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
         dk.stride(0), dk.stride(2), dk.stride(1),
         dv.stride(0), dv.stride(2), dv.stride(1),
-        nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
+        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
         seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
         # Can't use kwargs here because triton autotune expects key to be args, not kwargs
         # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
-        causal, d,
+        causal, BLOCK_HEADDIM,
         # SEQUENCE_PARALLEL=False,
         # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
         # num_warps=num_warps,
...
@@ -861,18 +861,18 @@ from flash_attn.flash_attn_triton import flash_attn_func
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('d', [40, 64, 128, 88])
-# @pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [40])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 128)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 8
+    batch_size = 32
     nheads = 4
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
     k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
@@ -887,8 +887,6 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
 
-    run_bwd = d in [16, 32, 64, 128]
-    if run_bwd:
     g = torch.randn_like(output)
     dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
     dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
@@ -896,16 +894,19 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
     print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
     print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+    print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
+    print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
     print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
     print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
     print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
+    print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
 
-    if run_bwd:
     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()