"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e2c935f5615a3c15ee7439fa8a560edd5f13a457"
Unverified commit 811a9caa, authored by Guang Yang and committed by GitHub

Make static cache compatible with torch.export (#32168)

parent 7f5d644e
@@ -23,12 +23,14 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
 
 
-@dataclass
-class Cache:
+class Cache(torch.nn.Module):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
 
+    def __init__(self):
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
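Why the base-class change matters: `torch.export()` can only record in-place mutations on state the module system knows about, i.e. registered buffers (or parameters). Making `Cache` a `torch.nn.Module` is what allows the per-layer cache tensors further down to be registered as buffers. A minimal, self-contained sketch of that pattern (a hypothetical toy cache, not code from this PR; assumes a recent torch, roughly 2.3+):

```python
import torch


class ToyKVCache(torch.nn.Module):
    """Toy stand-in for a static cache: a pre-allocated buffer updated in place."""

    def __init__(self, max_len: int = 8, dim: int = 4):
        super().__init__()
        # A registered buffer (unlike a plain tensor attribute) is tracked by
        # torch.export, so the in-place write below is captured as a buffer mutation.
        self.register_buffer("keys", torch.zeros(max_len, dim))

    def forward(self, new_key: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        # Similar idiom to StaticCache.update(): scatter the new entry into the buffer.
        self.keys.index_copy_(0, pos, new_key)
        return self.keys.clone()


toy = ToyKVCache()
ep = torch.export.export(toy, args=(torch.ones(1, 4), torch.tensor([0])))
print(type(ep))  # torch.export.ExportedProgram
```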
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
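The `super().__init__()` calls added here and in the other cache subclasses below are not cosmetic: once `Cache` inherits from `torch.nn.Module`, the module's internal registries must exist before a subclass can register buffers or assign sub-modules (such as the sub-caches held by `EncoderDecoderCache`). A small illustrative sketch of the failure mode (toy classes, not from this PR):

```python
import torch


class BadCache(torch.nn.Module):
    def __init__(self):
        # Missing super().__init__() here.
        self.register_buffer("kv", torch.zeros(4))


class GoodCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kv", torch.zeros(4))


try:
    BadCache()
except AttributeError as err:
    print(err)  # roughly: "cannot assign buffer before Module.__init__() call"

GoodCache()  # works
```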
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        super().__init__()
         self._quantized_key_cache: List[torch.Tensor] = []
         self._quantized_value_cache: List[torch.Tensor] = []
@@ -634,6 +638,7 @@ class SinkCache(Cache):
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
@@ -786,7 +791,7 @@ class SinkCache(Cache):
 class StaticCache(Cache):
     """
-    Static Cache class to be used with `torch.compile(model)`.
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
 
     Parameters:
         config (`PretrainedConfig`):
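For readers who only use the high-level API: the usual way to get a `StaticCache` together with `torch.compile` is through `generate()`. A hedged usage sketch — the checkpoint is the gated `google/gemma-2b` model used in the new test below, and the kwargs follow the transformers static-cache documentation of this era; adjust to your setup:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b")  # gated repo: requires an authenticated token
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float32)

# Compile the forward pass; the fixed-shape StaticCache avoids shape-driven recompilations.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tok("The best color is", return_tensors="pt")
# cache_implementation="static" asks generate() to allocate a StaticCache internally.
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))
```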
@@ -817,18 +822,22 @@ class StaticCache(Cache):
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
+        # Note: There will be a significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        for _ in range(config.num_hidden_layers):
+        for idx in range(config.num_hidden_layers):
+            # Note: `torch.export()` requires mutations to be registered as buffers.
+            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            key_cache = getattr(self, f"key_cache_{idx}")
+            value_cache = getattr(self, f"value_cache_{idx}")
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             # it is not needed anyway)
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+                torch._dynamo.mark_static_address(key_cache)
+                torch._dynamo.mark_static_address(value_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(value_cache)
 
     def update(
         self,
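Net effect of this block: every per-layer tensor is now a named buffer (`key_cache_{idx}` / `value_cache_{idx}`), while `self.key_cache` / `self.value_cache` keep holding the very same tensors so `update()` is unchanged. A quick inspection sketch (uses the same gated checkpoint as the new test; shapes are illustrative):

```python
import torch
from transformers import AutoConfig, StaticCache

config = AutoConfig.from_pretrained("google/gemma-2b")  # gated repo: requires an authenticated token
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

names = [name for name, _ in cache.named_buffers()]
print(names[:4])  # e.g. ['key_cache_0', 'value_cache_0', 'key_cache_1', 'value_cache_1']
print(cache.key_cache[0].shape)  # (1, num_key_value_heads, 32, head_dim) for this config
print(cache.key_cache[0] is getattr(cache, "key_cache_0"))  # True: the list entries alias the buffers
```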
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
     """
 
     def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        super().__init__()
         self.self_attention_cache = self_attention_cache
         self.cross_attention_cache = cross_attention_cache
@@ -1148,6 +1159,7 @@ class EncoderDecoderCache(Cache):
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
......
@@ -15,12 +15,14 @@
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
 from transformers.testing_utils import (
     is_torch_available,
     require_auto_gptq,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -32,6 +34,7 @@ if is_torch_available():
     import torch
 
     from transformers import (
+        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
@@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    @slow
+    @require_read_token
+    def test_static_cache_exportability(self):
+        """
+        Tests that static cache works with `torch.export()`
+        """
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        device = "cpu"
+        dtype = torch.float32
+        max_batch_size = 1
+
+        config = AutoConfig.from_pretrained(
+            "google/gemma-2b",
+            torch_dtype=dtype,
+            use_cache=True,
+        )
+        m = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2b",
+            config=config,
+            torch_dtype=dtype,
+            attn_implementation="sdpa",  # Export and ExecuTorch only work with SdpaAttention
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
+
+        class ExportatibleModelWithStaticCache(torch.nn.Module):
+            def __init__(self, config, model):
+                super().__init__()
+                self.config = config
+                self.model = model
+                self.static_cache = StaticCache(
+                    config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
+                )
+
+            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
+                outs = self.model(
+                    input_ids=tokens,
+                    attention_mask=None,
+                    position_ids=input_pos.unsqueeze(0),
+                    cache_position=input_pos,
+                    past_key_values=self.static_cache,
+                    use_cache=True,
+                )
+                return outs.logits
+
+        set_seed(0)
+        with torch.no_grad():
+            from torch.export import ExportedProgram, export
+
+            model = ExportatibleModelWithStaticCache(config, m)
+            exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
+            self.assertTrue(isinstance(exported_program, ExportedProgram))
+
     @require_torch_gpu
     @slow
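Not part of the diff, but a common follow-up question: what do you do with the resulting `ExportedProgram`? A hedged sketch, continuing from the variables defined in the test above — `ExportedProgram.module()` returns a runnable graph module that carries its own copy of the static-cache buffers:

```python
# Continuing from `inputs` and `exported_program` in test_static_cache_exportability:
exported_model = exported_program.module()
logits = exported_model(inputs, torch.arange(1))
print(logits.shape)  # (batch, sequence_length, vocab_size)
# From here the program can be lowered further, e.g. to ExecuTorch for on-device inference.
```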
......