Commit 254e8815 authored by Jimmy Zhang

refactor flash attention

parent f1a50a3c
@@ -361,43 +361,35 @@ class FlashSelfAttention(torch.nn.Module):
         ---------
             q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-        assert q.dtype in [torch.float16, torch.bfloat16]
-        assert q.is_cuda
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
+        assert all((i.is_cuda for i in (q,k,v)))
         batch_size, seqlen_q = q.shape[0], q.shape[1]
-        if self.training:
-            # during training q,k,v all have same seqlen
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
-                                      device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens, cu_seqlens, seqlen_q, seqlen_q,
-                self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-        else:
-            # during inference q seqlen is different than k,v seqlen
-            assert k.dtype in [torch.float16, torch.bfloat16]
-            assert k.is_cuda
+        seqlen_k = k.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                    device=q.device)
+        if self.training:
+            # during training q,k,v always have same seqlen
+            assert seqlen_k == seqlen_q
+            is_causal = self.causal
+            cu_seqlens_k = cu_seqlens_q
+        else:
             # turn off FA causal mask after first inference autoregressive iteration
-            # only on first autoregressive step do q,k,v have same seqlen
-            seqlen_k = k.shape[1]
+            # only on first autoregressive step q,k,v have same seqlen
             is_causal = seqlen_q == seqlen_k
-            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
-                                        device=q.device)
             cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                         device=q.device)
-            output = flash_attn_unpadded_func(
-                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                0.0,
-                softmax_scale=self.softmax_scale, causal=is_causal
-            )
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+            0.0,
+            softmax_scale=self.softmax_scale, causal=is_causal
+        )
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
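
For reference, a minimal sketch of how the refactored forward pass might be exercised. It assumes flash-attn is installed and a CUDA device is available; the import path, constructor arguments, and tensor sizes are illustrative assumptions, not part of this commit.

import torch
# Import path is an assumption; FlashSelfAttention lives wherever this diff's file does.
# from megatron.model.transformer import FlashSelfAttention

# Illustrative sizes: a single-query autoregressive step against a 128-token
# key/value sequence, half precision on GPU as the asserts require.
batch_size, num_heads, head_dim = 2, 16, 64
seqlen_q, seqlen_k = 1, 128

q = torch.randn(batch_size, seqlen_q, num_heads, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seqlen_k, num_heads, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn_like(k)

# Constructor arguments are assumed from the attributes used in forward()
# (self.causal, self.softmax_scale, self.dropout_p); check the class __init__.
attn = FlashSelfAttention(causal=True, attention_dropout=0.1).cuda().eval()

with torch.no_grad():
    # Not training and seqlen_q != seqlen_k, so the refactored code takes the
    # else branch: is_causal becomes False and a separate cu_seqlens_k is built.
    out = attn(q, k, v)

assert out.shape == (batch_size, seqlen_q, num_heads, head_dim)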