Commit cdd2afdf authored by Jared Casper

Merge branch 'flash_attention_inference' into 'main'

Flash Attention inference fix

See merge request ADLR/megatron-lm!562
parents 717c5274 a7cce320
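
The change below teaches `FlashSelfAttention.forward` to handle autoregressive inference, where after the first step the query holds only the newly generated token (`seqlen_q == 1`) while the keys and values span the whole cache (`seqlen_k > seqlen_q`). It builds separate `cu_seqlens_q`/`cu_seqlens_k` offsets, disables the causal flag once the two lengths diverge, and zeroes dropout on the eval path. A standalone sketch of the two ideas the diff relies on follows it.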
```diff
@@ -361,18 +361,37 @@ class FlashSelfAttention(torch.nn.Module):
         ---------
         q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
         """
-        assert q.dtype in [torch.float16, torch.bfloat16]
-        assert q.is_cuda
-        batch_size, seqlen = q.shape[0], q.shape[1]
+
+        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
+        assert all((i.is_cuda for i in (q,k,v)))
+
+        batch_size, seqlen_q = q.shape[0], q.shape[1]
+        seqlen_k = k.shape[1]
+
         q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
-        max_s = seqlen
-        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                  device=q.device)
+        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                    device=q.device)
+
+        if self.training:
+            # during training q,k,v always have same seqlen
+            assert seqlen_k == seqlen_q
+
+            is_causal = self.causal
+            cu_seqlens_k = cu_seqlens_q
+        else:
+            # turn off FA causal mask after first inference autoregressive iteration
+            # only on first autoregressive step q,k,v have same seqlen
+            is_causal = seqlen_q == seqlen_k
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
+                                        device=q.device)
+            self.dropout_p = 0
+
         output = flash_attn_unpadded_func(
-            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
-            self.dropout_p if self.training else 0.0,
-            softmax_scale=self.softmax_scale, causal=self.causal
+            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+            self.dropout_p,
+            softmax_scale=self.softmax_scale, causal=is_causal
         )
+
         output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         return output
```
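
A minimal, self-contained sketch of those two ideas in plain PyTorch; it is not part of the merge request. `batch_size`, `seqlen_q`, `seqlen_k`, and `d` are illustrative values, and the note about mask alignment below is my reading of the flash-attn version this change targets.

```python
import torch

torch.manual_seed(0)

# cu_seqlens layout consumed by the varlen flash-attn API: q/k/v are packed
# as (total_tokens, H, D), and cumulative offsets mark each batch element.
batch_size, seqlen_q, seqlen_k = 2, 1, 5   # one decode step over 5 cached keys
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
                            dtype=torch.int32)
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
                            dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 1, 2]  -> queries 0 and 1
print(cu_seqlens_k.tolist())  # [0, 5, 10] -> keys 0..4 and 5..9

# Why causal can be switched off when seqlen_q == 1: the new query is the
# *last* position of the sequence, so a correctly aligned causal mask hides
# nothing, and unmasked attention over the whole KV cache is identical.
d = 8
q = torch.randn(1, d)                          # newest token's query
k, v = torch.randn(seqlen_k, d), torch.randn(seqlen_k, d)

scores = (q @ k.T) / d ** 0.5                  # (1, seqlen_k)
unmasked = torch.softmax(scores, dim=-1) @ v

future = torch.arange(seqlen_k) > (seqlen_k - 1)   # no key lies in the future
masked = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1) @ v

assert torch.allclose(unmasked, masked)
```

In the flash-attn release this MR appears to target, the causal mask is aligned to the top-left corner of the score matrix, so keeping `causal=True` with a one-row query would mask it against every cached key except the first; that mismatch is what disabling the flag avoids.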