"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "32883b310ba30d72e67bb2ebb5847888f03a90a8"
Unverified commit 51ab25e2, authored by OsamaS99 and committed by GitHub

Fixed Hybrid Cache Shape Initialization. (#32163)



* fixed hybrid cache init, added test

* Fix Test Typo

---------
Co-authored-by: Aaron Haag <aaron.haag@siemens.com>
parent e3d8285a
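For context: static and hybrid caches (e.g. Gemma 2's HybridCache) are pre-allocated with a fixed batch dimension, and before this change that dimension ignored num_return_sequences. Below is a minimal sketch of a call that the old shape computation could not accommodate, adapted from the test added in this commit (it assumes access to the google/gemma-2-9b checkpoint):

# Minimal sketch, adapted from the new test below. Before this fix the hybrid
# cache was allocated for num_beams * batch_size sequences only, so requesting
# two return sequences per prompt overflowed the cache's batch dimension.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b", device_map="auto", torch_dtype=torch.bfloat16
)
inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)

# Needs a cache batch of num_beams (1) * num_return_sequences (2) * batch_size (1) = 2.
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=20, num_return_sequences=2)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))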
@@ -1813,7 +1813,9 @@ class GenerationMixin:
                 )
                 model_kwargs[cache_name] = self._get_cache(
                     generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1) * batch_size,
+                    getattr(generation_config, "num_beams", 1)
+                    * getattr(generation_config, "num_return_sequences", 1)
+                    * batch_size,
                     generation_config.max_length,
                     model_kwargs,
                 )
...
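The allocated cache batch is now the product of every per-prompt expansion factor. A sketch of the new computation from the hunk above, with the surrounding _get_cache call elided (generation_config and batch_size are assumed to be in scope, as in the caller):

# e.g. num_beams=1, num_return_sequences=2, batch_size=1 -> cache_batch_size == 2
cache_batch_size = (
    getattr(generation_config, "num_beams", 1)               # beam search width, default 1
    * getattr(generation_config, "num_return_sequences", 1)  # sequences returned per prompt
    * batch_size                                             # prompts in the input batch
)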
@@ -292,6 +292,30 @@ class CacheIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(decoded, expected_text)
 
+    def test_hybrid_cache_n_sequences(self):
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+        model = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2-9b",
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="eager",
+        )
+        inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
+        gen_out = model.generate(
+            **inputs,
+            do_sample=False,
+            max_new_tokens=20,
+            num_return_sequences=2,
+        )
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        expected_text = [
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+            "Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
+        ]
+        self.assertListEqual(decoded, expected_text)
+
     @require_auto_gptq
     def test_sink_cache_hard(self):
         tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
...