Unverified commit f0ed3d50, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Avoid using LRU cache for cu_seqlens (#798)



* Try using global buffer for cu_seqlens
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Avoid using functools.lru_cache
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent dac00019
......@@ -5,7 +5,6 @@
"""Attention."""
import collections
from contextlib import nullcontext
import functools
from importlib.metadata import version
import math
import os
......@@ -278,8 +277,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
return indices
@functools.lru_cache
_cu_seqlens_cache = {}
def _get_full_cu_seqlens(
batch_size: int,
max_seqlen: int,
......@@ -290,13 +288,16 @@ def _get_full_cu_seqlens(
All sequences in batch have the maximum sequence length.
"""
return torch.arange(
global _cu_seqlens_cache
if (batch_size, max_seqlen) not in _cu_seqlens_cache:
_cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
0,
(batch_size + 1) * max_seqlen,
step=max_seqlen,
dtype=torch.int32,
device=device,
)
return _cu_seqlens_cache[(batch_size, max_seqlen)]
@jit_fuser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.