Commit f1a50a3c authored by Jimmy Zhang

Flash Attention inference fix

parent 717c5274
@@ -363,16 +363,42 @@ class FlashSelfAttention(torch.nn.Module):
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda
-        batch_size, seqlen = q.shape[0], q.shape[1]
-        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-        max_s = seqlen
-        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                  device=q.device)
-        output = flash_attn_unpadded_func(
-            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
-            self.dropout_p if self.training else 0.0,
-            softmax_scale=self.softmax_scale, causal=self.causal
-        )
+
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+
+        if self.training:
+            # during training q, k, v all have the same seqlen
+            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
+                                      dtype=torch.int32, device=q.device)
+
+            output = flash_attn_unpadded_func(
+                q, k, v, cu_seqlens, cu_seqlens, seqlen_q, seqlen_q,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+        else:
+            # during inference the q seqlen differs from the k, v seqlen
+            assert k.dtype in [torch.float16, torch.bfloat16]
+            assert k.is_cuda
+            # turn off the FA causal mask after the first autoregressive step:
+            # only on that first step do q, k, v share the same seqlen
+            seqlen_k = k.shape[1]
+            is_causal = seqlen_q == seqlen_k
+            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
+                                        dtype=torch.int32, device=q.device)
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
+                                        dtype=torch.int32, device=q.device)
+            output = flash_attn_unpadded_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                0.0,
+                softmax_scale=self.softmax_scale, causal=is_causal
+            )
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
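For context, a minimal usage sketch of the new inference path follows. It assumes a CUDA device with flash-attn installed, the FlashSelfAttention(causal, softmax_scale, attention_dropout) constructor defined earlier in this file, and q, k, v shaped (batch, seqlen, heads, headdim); the batch size, head count, and sequence lengths are illustrative only, not part of the commit.

import torch

# Illustrative hyperparameters (hypothetical, not from the commit).
attn = FlashSelfAttention(causal=True, softmax_scale=None, attention_dropout=0.0)
attn.eval()  # takes the inference branch: dropout forced to 0.0, causal decided per step

b, h, d = 2, 16, 64
with torch.no_grad():
    # First autoregressive step: the prompt gives q, k, v the same seqlen (8),
    # so seqlen_q == seqlen_k and the causal mask stays on.
    q, k, v = [torch.randn(b, 8, h, d, dtype=torch.float16, device='cuda')
               for _ in range(3)]
    out = attn(q, k, v)   # -> (b, 8, h, d)

    # Later steps: a single new query token attends over a 9-entry k/v cache,
    # so seqlen_q != seqlen_k and is_causal is switched off.
    q = torch.randn(b, 1, h, d, dtype=torch.float16, device='cuda')
    k = torch.randn(b, 9, h, d, dtype=torch.float16, device='cuda')
    v = torch.randn(b, 9, h, d, dtype=torch.float16, device='cuda')
    out = attn(q, k, v)   # -> (b, 1, h, d)

The toggle matters because, in the flash-attn versions this code targets, the causal mask aligns each query with the start of the key sequence: with causal=True a single decoded query would only be allowed to see key position 0. Disabling the mask once seqlen_q != seqlen_k lets each new token attend to the entire cache, which is the behavior autoregressive decoding needs.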