Unverified commit da30634a, authored by Tim Moon and committed by GitHub

[PyTorch] Refactor caching of cumulative sequence lengths (#630)



Do not cache sequence lengths based on layer number
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d68028c8
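For context on the change: the previous code cached cumulative sequence lengths in module-level globals (_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv) and only refreshed them when self.layer_number == 1, which implicitly assumes every layer sees the same batch. The refactor computes these values on each forward call and memoizes only the input-independent full-batch case with functools.lru_cache. A minimal standalone sketch of that caching behaviour (illustrative values on CPU, not taken verbatim from the library):

    import functools

    import torch

    @functools.lru_cache
    def _get_full_cu_seqlens(batch_size: int, max_seqlen: int, device: torch.device) -> torch.Tensor:
        # [0, max_seqlen, 2*max_seqlen, ..., batch_size*max_seqlen]
        return torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )

    cu_seqlens = _get_full_cu_seqlens(4, 128, torch.device("cpu"))
    print(cu_seqlens)  # tensor([  0, 128, 256, 384, 512], dtype=torch.int32)
    # Repeated calls with the same hashable arguments return the same cached tensor,
    # independent of which layer asks first.
    assert cu_seqlens is _get_full_cu_seqlens(4, 128, torch.device("cpu"))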
@@ -4,14 +4,16 @@
 """Attention."""
 import collections
+from contextlib import nullcontext
+import functools
+from importlib.metadata import version
+import math
 import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import warnings
-import math
-from importlib.metadata import version
-from contextlib import nullcontext
-from typing import Any, Callable, List, Optional, Tuple, Union, Dict
-from pkg_resources import packaging
 
 import numpy as np
+from pkg_resources import packaging
 import torch
 import torch.nn.functional as F
@@ -68,7 +70,6 @@ if _flash_attn_version >= _flash_attn_version_required:
     from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
 
-_cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv = None, None, None, None
 
 _NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
@@ -214,6 +215,26 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
     return indices
+
+
+@functools.lru_cache
+def _get_full_cu_seqlens(
+    batch_size: int,
+    max_seqlen: int,
+    device: torch.device,
+) -> torch.Tensor:
+    """Cumulative sequence lengths in full data batch
+
+    All sequences in batch have the maximum sequence length.
+
+    """
+    return torch.arange(
+        0,
+        (batch_size + 1) * max_seqlen,
+        step=max_seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+
 
 @jit_fuser
 def pack_tensor(
     indices: torch.Tensor,
@@ -1652,7 +1673,6 @@ class FlashAttention(torch.nn.Module):
             query_layer, key_layer, value_layer = [x.contiguous()
                                                    for x in (query_layer, key_layer, value_layer)]
 
-        global _cu_seqlens_q, _cu_seqlens_kv, _indices_q, _indices_kv
         batch_size = query_layer.shape[0]
         if qkv_format in ['sbhd', 'bshd']:
@@ -1671,58 +1691,45 @@
                     assert (
                         max_seqlen_q == max_seqlen_kv
                     ), "Maximum sequence length for Q and KV should be the same."
-                    if self.layer_number == 1:
-                        if cu_seqlens_q is None:
-                            assert (attention_mask is not None
-                                    ), "Please provide attention_mask for padding!"
-                            _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask)
-                        else:
-                            _cu_seqlens_q = cu_seqlens_q
-                            _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
-                        _cu_seqlens_kv = _cu_seqlens_q
-                    query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply(
-                        _indices_q, query_layer, key_layer, value_layer
+                    if cu_seqlens_q is None:
+                        assert (attention_mask is not None
+                                ), "Please provide attention_mask for padding!"
+                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask)
+                    else:
+                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                    cu_seqlens_kv = cu_seqlens_q
+                    query_layer, key_layer, value_layer = PackTensors.apply(
+                        indices_q, query_layer, key_layer, value_layer
                     )
                 else:
-                    if self.layer_number == 1:
-                        if cu_seqlens_q is None or cu_seqlens_kv is None:
-                            assert (attention_mask is not None
-                                    ), "Please provide attention_mask for padding!"
-                            _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(
-                                attention_mask[0])
-                            _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices(
-                                attention_mask[1])
-                        else:
-                            _cu_seqlens_q = cu_seqlens_q
-                            _cu_seqlens_kv = cu_seqlens_kv
-                            _indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
-                            _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
-                    query_layer_packed = PackTensors.apply(_indices_q, query_layer)
-                    key_layer_packed, value_layer_packed = PackTensors.apply(
-                        _indices_kv, key_layer, value_layer
+                    if cu_seqlens_q is None or cu_seqlens_kv is None:
+                        assert (attention_mask is not None
+                                ), "Please provide attention_mask for padding!"
+                        cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(
+                            attention_mask[0])
+                        cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(
+                            attention_mask[1])
+                    else:
+                        indices_q = get_indices(max_seqlen_q, cu_seqlens_q)
+                        indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv)
+                    query_layer = PackTensors.apply(indices_q, query_layer)
+                    key_layer, value_layer = PackTensors.apply(
+                        indices_kv, key_layer, value_layer
                    )
-                query_layer, key_layer, value_layer = (
-                    query_layer_packed, key_layer_packed, value_layer_packed)
-                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
             else:
-                if self.layer_number == 1:
-                    if cu_seqlens_q is None:
-                        cu_seqlens_q = torch.arange(
-                            0,
-                            (batch_size + 1) * max_seqlen_q,
-                            step=max_seqlen_q,
-                            dtype=torch.int32,
-                            device=query_layer.device)
-                    if cu_seqlens_kv is None:
-                        cu_seqlens_kv = torch.arange(
-                            0,
-                            (batch_size + 1) * max_seqlen_kv,
-                            step=max_seqlen_kv,
-                            dtype=torch.int32,
-                            device=key_layer.device)
-                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
-                else:
-                    cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+                # Cumulative sequence lengths for unpadded data
+                if cu_seqlens_q is None:
+                    cu_seqlens_q = _get_full_cu_seqlens(
+                        batch_size,
+                        max_seqlen_q,
+                        query_layer.device,
+                    )
+                if cu_seqlens_kv is None:
+                    cu_seqlens_kv = _get_full_cu_seqlens(
+                        batch_size,
+                        max_seqlen_kv,
+                        key_layer.device,
+                    )
         elif qkv_format == 'thd':
             assert not context_parallel, "thd format not supported with context parallelism!"
             assert (cu_seqlens_q is not None and cu_seqlens_kv is not None
@@ -1777,7 +1784,7 @@
             )
 
         if 'padding' in attn_mask_type:
-            output = UnpackTensor.apply(_indices_q, batch_size * max_seqlen_q, output)
+            output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
 
         if qkv_format == 'sbhd':
             # (bs)hd -> bs(hd) -> sb(hd)
@@ -2135,42 +2142,30 @@ class FusedAttention(torch.nn.Module):
         if 'padding' in attn_mask_type:
             assert not context_parallel, "Padding mask not supported with context parallelism!"
 
-            global _cu_seqlens_q, _cu_seqlens_kv
-            if (cu_seqlens_q is not None and cu_seqlens_kv is not None):
-                # use cu_seqlens when both cu_seqlens and attention_mask are present
-                if self.layer_number == 1:
-                    _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
-            elif attention_mask is not None:
+            if cu_seqlens_q is None or cu_seqlens_kv is None:
+                if attention_mask is None:
+                    raise RuntimeError(
+                        "Please provide attention_mask or cu_seqlens for padding!"
+                    )
                 if self.attention_type == "self":
-                    if self.layer_number == 1:
-                        _cu_seqlens_q = get_cu_seqlens(attention_mask)
-                        _cu_seqlens_kv = _cu_seqlens_q
+                    cu_seqlens_q = get_cu_seqlens(attention_mask)
+                    cu_seqlens_kv = cu_seqlens_q
                 else:
-                    if self.layer_number == 1:
-                        _cu_seqlens_q = get_cu_seqlens(attention_mask[0])
-                        _cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
-            else:
-                raise Exception("Please provide attention_mask or cu_seqlens for padding!")
-            cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+                    cu_seqlens_q = get_cu_seqlens(attention_mask[0])
+                    cu_seqlens_kv = get_cu_seqlens(attention_mask[1])
         else:
-            if self.layer_number == 1:
-                if cu_seqlens_q is None:
-                    cu_seqlens_q = torch.arange(
-                        0,
-                        (batch_size + 1) * max_seqlen_q,
-                        step=max_seqlen_q,
-                        dtype=torch.int32,
-                        device=query_layer.device)
-                if cu_seqlens_kv is None:
-                    cu_seqlens_kv = torch.arange(
-                        0,
-                        (batch_size + 1) * max_seqlen_kv,
-                        step=max_seqlen_kv,
-                        dtype=torch.int32,
-                        device=key_layer.device)
-                _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv
-            else:
-                cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv
+            if cu_seqlens_q is None:
+                cu_seqlens_q = _get_full_cu_seqlens(
+                    batch_size,
+                    max_seqlen_q,
+                    query_layer.device,
+                )
+            if cu_seqlens_kv is None:
+                cu_seqlens_kv = _get_full_cu_seqlens(
+                    batch_size,
+                    max_seqlen_kv,
+                    key_layer.device,
+                )
 
         qkv_dtype = TE_DType[query_layer.dtype]
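The cu_seqlens tensors used throughout follow the flash-attention varlen convention: the cumulative sum of per-sequence lengths with a leading zero, so sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed data. In the padding paths above they are now derived from the attention mask on every call instead of being read from a cached global. A rough standalone illustration of that derivation (an analogy only, not the library's get_cu_seqlens implementation; the mask convention here, True marking padded positions, is an assumption):

    import torch

    # Hypothetical padding mask for 2 sequences of max length 4,
    # with actual lengths 3 and 2 (True = padded position).
    mask = torch.tensor([[False, False, False, True],
                         [False, False, True,  True]])
    seqlens = (~mask).sum(dim=1, dtype=torch.int32)
    cu_seqlens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        torch.cumsum(seqlens, dim=0, dtype=torch.int32),
    ])
    print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)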