Unverified Commit f0ed3d50 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Avoid using LRU cache for cu_seqlens (#798)



* Try using global buffer for cu_seqlens
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid using functools.lru_cache
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent dac00019
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"""Attention.""" """Attention."""
import collections import collections
from contextlib import nullcontext from contextlib import nullcontext
import functools
from importlib.metadata import version from importlib.metadata import version
import math import math
import os import os
...@@ -278,8 +277,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: ...@@ -278,8 +277,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
return indices return indices
_cu_seqlens_cache = {}
@functools.lru_cache
def _get_full_cu_seqlens( def _get_full_cu_seqlens(
batch_size: int, batch_size: int,
max_seqlen: int, max_seqlen: int,
...@@ -290,13 +288,16 @@ def _get_full_cu_seqlens( ...@@ -290,13 +288,16 @@ def _get_full_cu_seqlens(
All sequences in batch have the maximum sequence length. All sequences in batch have the maximum sequence length.
""" """
return torch.arange( global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
0, 0,
(batch_size + 1) * max_seqlen, (batch_size + 1) * max_seqlen,
step=max_seqlen, step=max_seqlen,
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
return _cu_seqlens_cache[(batch_size, max_seqlen)]
@jit_fuser @jit_fuser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.