Commit 254e8815 authored by Jimmy Zhang

refactor flash attention

parent f1a50a3c
@@ -361,43 +361,35 @@ class FlashSelfAttention(torch.nn.Module):
         ---------
             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-        assert q.dtype in [torch.float16, torch.bfloat16]
-        assert q.is_cuda
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
+        assert all((i.is_cuda for i in (q,k,v)))
         batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = k.shape[1]
-        if self.training:
-            # during training q,k,v all have same seqlen
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                     device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens, cu_seqlens, seqlen_q, seqlen_q,
-                self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-        else:
-            # during inference q seqlen is different than k,v seqlen
-            assert k.dtype in [torch.float16, torch.bfloat16]
-            assert k.is_cuda
+        if self.training:
+            # during training q,k,v always have same seqlen
+            assert seqlen_k == seqlen_q
+            is_causal = self.causal
+            cu_seqlens_k = cu_seqlens_q
+        else:
             # turn off FA causal mask after first inference autoregressive iteration
-            # only on first autoregressive step do q,k,v have same seqlen
-            seqlen_k = k.shape[1]
+            # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
-                                        device=q.device)
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
-                                        device=q.device)
+                        device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                0.0,
-                softmax_scale=self.softmax_scale, causal=is_causal
-            )
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+            0.0,
+            softmax_scale=self.softmax_scale, causal=is_causal
+        )
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
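
For readers without the surrounding file, the sketch below restates the control flow this refactor converges on: a single packed (B*S, H, D) call path where cu_seqlens_q / cu_seqlens_k describe per-sequence boundaries, and the causal flag is dropped once inference moves past the first autoregressive step. This is a minimal illustration, not Megatron-LM code: naive_varlen_attention is a hypothetical CPU stand-in for flash_attn_unpadded_func from the flash-attn package, and the dropout handling is illustrative (the committed call passes 0.0 directly), so treat it as a sketch of the technique under those assumptions.

import torch
import torch.nn.functional as F
from einops import rearrange


def naive_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                           max_seqlen_q, max_seqlen_k, dropout_p,
                           softmax_scale=None, causal=False):
    # Reference attention over packed (total_tokens, heads, head_dim) inputs.
    # Hypothetical O(S^2) stand-in for the flash-attn varlen kernel, kept CPU-friendly.
    softmax_scale = softmax_scale or q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(cu_seqlens_q.numel() - 1):
        qs, qe = cu_seqlens_q[i].item(), cu_seqlens_q[i + 1].item()
        ks, ke = cu_seqlens_k[i].item(), cu_seqlens_k[i + 1].item()
        qi = q[qs:qe].transpose(0, 1)                      # (H, Sq, D)
        ki = k[ks:ke].transpose(0, 1)                      # (H, Sk, D)
        vi = v[ks:ke].transpose(0, 1)                      # (H, Sk, D)
        scores = torch.matmul(qi, ki.transpose(-2, -1)) * softmax_scale
        if causal:
            mask = torch.triu(torch.ones(qi.shape[1], ki.shape[1],
                                         dtype=torch.bool, device=scores.device), 1)
            scores = scores.masked_fill(mask, float('-inf'))
        probs = F.dropout(scores.softmax(dim=-1), p=dropout_p, training=dropout_p > 0)
        out[qs:qe] = torch.matmul(probs, vi).transpose(0, 1)
    return out


def flash_self_attention_forward(q, k, v, training, causal=True, dropout_p=0.0):
    # Mirrors the refactored FlashSelfAttention.forward; q, k, v are (B, S, H, D).
    batch_size, seqlen_q = q.shape[0], q.shape[1]
    seqlen_k = k.shape[1]

    # Pack the batch into one (B*S, H, D) tensor and describe it with cumulative lengths.
    q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
    cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                                dtype=torch.int32, device=q.device)

    if training:
        # During training q, k, v share one sequence length, so k reuses q's boundaries.
        assert seqlen_k == seqlen_q
        is_causal = causal
        cu_seqlens_k = cu_seqlens_q
    else:
        # After the first autoregressive step seqlen_q (1) < seqlen_k, so the causal
        # mask is only applied while the two lengths still match.
        is_causal = seqlen_q == seqlen_k
        cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                                    dtype=torch.int32, device=q.device)
        dropout_p = 0.0

    output = naive_varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                    seqlen_q, seqlen_k, dropout_p, causal=is_causal)
    return rearrange(output, '(b s) ... -> b s ...', b=batch_size)


if __name__ == "__main__":
    x = torch.randn(2, 5, 4, 8)   # (B, S, H, D)
    print(flash_self_attention_forward(x, x, x, training=False).shape)  # torch.Size([2, 5, 4, 8])

Collapsing the two branches into one flash_attn_unpadded_func call means only is_causal and cu_seqlens_k differ between training and inference, which is the point of the refactor.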