"googlemock/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "2e308484d9693f8251748c295f6ed7ed25d767eb"
Unverified Commit fad3044b authored by Hongbin Liu's avatar Hongbin Liu Committed by GitHub
Browse files

Avoid redundant computation for cu_seqlens (#535)



avoid redundant computation for cu_seqlens
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 82555b3f
......@@ -1398,20 +1398,24 @@ class FlashAttention(torch.nn.Module):
query_layer_packed, key_layer_packed, value_layer_packed)
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
else:
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
if self.layer_number == 1:
if cu_seqlens_q is None:
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * max_seqlen_q,
step=max_seqlen_q,
dtype=torch.int32,
device=query_layer.device)
if cu_seqlens_kv is None:
cu_seqlens_kv = torch.arange(
0,
(batch_size + 1) * max_seqlen_kv,
step=max_seqlen_kv,
dtype=torch.int32,
device=key_layer.device)
_cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
else:
cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
elif qkv_format == 'thd':
assert not context_parallel, "thd format is not supported for context parallelism!"
assert (_flash_attn_2_available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment