Unverified commit 3e70a207, authored by Joao Gante and committed by GitHub
Browse files

Static Cache: load models with MQA or GQA (#28975)

parent da20209d
...@@ -351,10 +351,12 @@ class StaticCache(Cache): ...@@ -351,10 +351,12 @@ class StaticCache(Cache):
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype
cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim) cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0 self.seen_tokens = 0
......
...@@ -35,14 +35,16 @@ if is_torch_available(): ...@@ -35,14 +35,16 @@ if is_torch_available():
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
DynamicCache, DynamicCache,
LlamaConfig,
LlamaForCausalLM, LlamaForCausalLM,
SinkCache, SinkCache,
StaticCache,
) )
@require_torch @require_torch
class CacheTest(unittest.TestCase): class CacheTest(unittest.TestCase):
def test_cache_equivalence(self): def test_dynamic_cache_retrocompatibility(self):
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache""" """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
legacy_cache = () legacy_cache = ()
new_cache = DynamicCache() new_cache = DynamicCache()
...@@ -120,6 +122,48 @@ class CacheTest(unittest.TestCase): ...@@ -120,6 +122,48 @@ class CacheTest(unittest.TestCase):
) )
) )
def test_static_cache_mha_mqa_gqa(self):
    """
    Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
    attention (MQA)

    For each attention flavour a `StaticCache` is built from a `LlamaConfig`, a single random key/value step is
    written into layer 0, and the tensors returned by `update` are checked to have
    `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)` as their shape.
    """

    def _random_kvs(config):
        # shape for key and values: (batch_size, num_heads, seq_len, head_dim)
        head_dim = config.hidden_size // config.num_attention_heads
        kv_shape = (1, config.num_key_value_heads, 1, head_dim)
        random_keys = torch.rand(kv_shape, device=torch_device)
        random_values = torch.rand(kv_shape, device=torch_device)
        return random_keys, random_values

    # (label, extra config kwargs, expected number of cached key/value heads)
    # MHA passes no `num_key_value_heads`, so the cache is expected to hold one KV head per attention head.
    scenarios = (
        ("mha", {}, 32),
        ("gqa", {"num_key_value_heads": 4}, 4),
        ("mqa", {"num_key_value_heads": 1}, 1),
    )

    for label, extra_config_kwargs, expected_kv_heads in scenarios:
        # subTest keeps the three scenarios independent: one failing does not hide the others.
        with self.subTest(attention=label):
            config = LlamaConfig(num_attention_heads=32, **extra_config_kwargs)
            static_cache = StaticCache(config=config, max_batch_size=1, max_cache_len=10, device=torch_device)
            # Keep the index tensor on the same device as the cache and the key/value states.
            cached_keys, cached_values = static_cache.update(
                *_random_kvs(config),
                0,
                cache_kwargs={"position_ids": torch.arange(1, device=torch_device)},
            )
            # 128 is the head dim with 32 attention heads; this presumes the default
            # `LlamaConfig.hidden_size` is 4096 -- NOTE(review): confirm against the config defaults.
            expected_shape = (1, expected_kv_heads, 10, 128)
            # assertEqual (rather than assertTrue on `==`) reports both shapes on failure.
            self.assertEqual(tuple(cached_keys.shape), expected_shape)
            self.assertEqual(tuple(cached_values.shape), expected_shape)
@require_torch_gpu @require_torch_gpu
@slow @slow
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment