Unverified Commit b90b638d authored by Sangkug Lym, committed by GitHub

Provide pre-computed max sequence to remove unnecessary kernels and D2H copies (#555)



* Provide pre-computed max sequence to remove unnecessary kernels and D2H copies
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Tweak comments
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent cd798c97
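
The motivation for the change: deriving max_seqlen_q/max_seqlen_kv from cu_seqlens on device launches a subtraction kernel and a max-reduction kernel, and the trailing .item() copies the scalar result device-to-host, stalling the CPU until the GPU stream drains. Callers that pack their own batches already know these maxima, so accepting them as plain Python ints removes both the kernels and the blocking copy. A minimal sketch of the two paths (illustrative values, assumes a CUDA device is available):

import torch

cu_seqlens = torch.tensor([0, 37, 95, 128], dtype=torch.int32, device="cuda")

# Old path: runs on every forward pass.
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]  # subtraction kernel
max_seqlen = seqlens.max().item()           # reduction kernel + blocking D2H copy

# New path: the caller supplies the value it already knows;
# no kernels launch and nothing is copied off the device.
max_seqlen = 128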
@@ -1304,6 +1304,8 @@ class FlashAttention(torch.nn.Module):
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         window_size: Optional[Tuple[int, int]] = None,
         cp_group: Optional[dist_group_type] = None,
@@ -1346,10 +1348,10 @@ class FlashAttention(torch.nn.Module):
                     for x in (query_layer, key_layer, value_layer)]
 
         global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
-        batch_size, max_seqlen_q, max_seqlen_kv = (
-            query_layer.shape[0], query_layer.shape[1], key_layer.shape[1])
+        batch_size = query_layer.shape[0]
 
         if qkv_format in ['sbhd', 'bshd']:
+            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
             if not context_parallel:
                 # [b * s, h, d]
                 query_layer, key_layer, value_layer = [
...@@ -1422,10 +1424,12 @@ class FlashAttention(torch.nn.Module): ...@@ -1422,10 +1424,12 @@ class FlashAttention(torch.nn.Module):
), "flash-attn v2 is required for variable sequence length support!" ), "flash-attn v2 is required for variable sequence length support!"
assert (cu_seqlens_q is not None and cu_seqlens_kv is not None assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] if max_seqlen_q is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
max_seqlen_q = seqlens_q.max().item() max_seqlen_q = seqlens_q.max().item()
max_seqlen_kv = seqlens_kv.max().item() if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = seqlens_kv.max().item()
if context_parallel: if context_parallel:
assert ( assert (
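
The guard above keeps the on-device derivation as a fallback, so existing callers are unaffected and only callers that pass the new arguments skip the sync. The same idiom in isolation (a sketch with a hypothetical helper name, not TE API):

from typing import Optional

import torch

def resolve_max_seqlen(cu_seqlens: torch.Tensor,
                       max_seqlen: Optional[int] = None) -> int:
    # Prefer the caller-provided value; fall back to deriving it on
    # device, which costs two kernels and a blocking D2H copy.
    if max_seqlen is None:
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
        max_seqlen = int(seqlens.max().item())
    return max_seqlen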
@@ -1754,6 +1758,8 @@ class FusedAttention(torch.nn.Module):
         qkv_layout: str = "sbh3d",
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: str = "causal",
         attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
         fused_attention_backend:
@@ -2104,6 +2110,8 @@ class DotProductAttention(torch.nn.Module):
         qkv_format: Optional[str] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
         attn_mask_type: Optional[str] = None,
         window_size: Optional[Tuple[int, int]] = None,
         checkpoint_core_attention: bool = False,
@@ -2176,6 +2184,12 @@ class DotProductAttention(torch.nn.Module):
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                    Cumulative sum of sequence lengths in a batch for `key_layer` and `value_layer`,
                    with shape [batch_size + 1] and dtype torch.int32.
+        max_seqlen_q: Optional[int], default = `None`
+                   Maximum sequence length in `query_layer`.
+                   Calculated from `cu_seqlens_q` if not provided.
+        max_seqlen_kv: Optional[int], default = `None`
+                   Maximum sequence length in `key_layer` and `value_layer`.
+                   Calculated from `cu_seqlens_kv` if not provided.
         attn_mask_type: {`no_mask`, `padding`, `causal`, `padding,causal`, `causal,padding`,
                        `arbitrary`}, default = `None`. Type of attention mask passed into
                        softmax operation. 'padding,causal' and 'causal,padding' are equivalent.
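
With the documented defaults above, a caller that packs variable-length sequences on the host can pass the maxima directly. A hedged usage sketch (`dpa` stands for an already-constructed DotProductAttention module; the query/key/value tensors are assumed to exist in `thd` layout):

import torch

seqlens = [37, 95, 128]  # known on the host at batch-packing time
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(
    torch.tensor(seqlens, dtype=torch.int32, device="cuda"), dim=0)

out = dpa(query_layer, key_layer, value_layer,
          qkv_format="thd",
          cu_seqlens_q=cu_seqlens,
          cu_seqlens_kv=cu_seqlens,
          max_seqlen_q=max(seqlens),   # plain ints: no kernels, no D2H copy
          max_seqlen_kv=max(seqlens))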
@@ -2238,10 +2252,12 @@ class DotProductAttention(torch.nn.Module):
             assert (cu_seqlens_q.dtype == torch.int32
                     and cu_seqlens_kv.dtype == torch.int32
                     ), "cu_seqlens_q and cu_seqlens_kv must both be in dtype torch.int32!"
-            seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_q = seqlens_q.max().item()
-            max_seqlen_kv = seqlens_kv.max().item()
+            if max_seqlen_q is None:
+                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                max_seqlen_q = seqlens_q.max().item()
+            if max_seqlen_kv is None:
+                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                max_seqlen_kv = seqlens_kv.max().item()
 
         if qkv_format in ['sbhd', 'bshd']:
             assert (all(len(x.shape) == 4 for x in (query_layer, key_layer, value_layer))
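
Since this block runs once per DotProductAttention call, a model with N transformer layers pays N reductions and N blocking copies per step when the maxima are not supplied. Deriving everything once per batch and reusing it across layers amortizes the cost; one possible shape for that, as a sketch (hypothetical helper, not TE code):

import torch

def seqlen_metadata(attention_mask: torch.Tensor):
    """Compute cu_seqlens and max_seqlen once per batch from a [b, s]
    boolean mask of valid tokens, then reuse them for every layer."""
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32,
                             device=seqlens.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    max_seqlen = int(seqlens.max().item())  # one sync per batch, not per layer
    return cu_seqlens, max_seqlen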
@@ -2405,7 +2421,9 @@ class DotProductAttention(torch.nn.Module):
                     window_size=window_size,
                     cp_group=self.cp_group,
                     cp_global_ranks=self.cp_global_ranks,
-                    cp_stream=self.cp_stream)
+                    cp_stream=self.cp_stream,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
 
         assert (
             self.cp_group is None or get_distributed_world_size(self.cp_group) == 1
@@ -2428,7 +2446,9 @@ class DotProductAttention(torch.nn.Module):
                     fused_attention_backend = fused_attention_backend,
                     core_attention_bias_type = core_attention_bias_type,
                     core_attention_bias = core_attention_bias,
-                    fast_zero_fill = fast_zero_fill)
+                    fast_zero_fill = fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
             return self.fused_attention(query_layer, key_layer, value_layer,
                     qkv_layout = qkv_layout,
                     cu_seqlens_q = cu_seqlens_q,
@@ -2438,7 +2458,9 @@ class DotProductAttention(torch.nn.Module):
                     fused_attention_backend = fused_attention_backend,
                     core_attention_bias_type = core_attention_bias_type,
                     core_attention_bias = core_attention_bias,
-                    fast_zero_fill = fast_zero_fill)
+                    fast_zero_fill = fast_zero_fill,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_kv=max_seqlen_kv)
 
         if _NVTE_DEBUG:
             print("[DotProductAttention]: using unfused DPA")