Unverified commit fad3044b, authored by Hongbin Liu, committed via GitHub
Browse files

Avoid redundant computation for cu_seqlens (#535)



avoid redundant computation for cu_seqlens
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 82555b3f
@@ -1398,6 +1398,7 @@ class FlashAttention(torch.nn.Module):
                     query_layer_packed, key_layer_packed, value_layer_packed)
                 cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
             else:
+                if self.layer_number == 1:
                     if cu_seqlens_q is None:
                         cu_seqlens_q = torch.arange(
                             0,
@@ -1412,6 +1413,9 @@ class FlashAttention(torch.nn.Module):
                             step=max_seqlen_kv,
                             dtype=torch.int32,
                             device=key_layer.device)
+                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
+                else:
+                    cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
         elif qkv_format == 'thd':
             assert not context_parallel, "thd format is not supported for context parallelism!"
             assert (_flash_attn_2_available
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment