Unverified commit b09912c8, authored by Yanan Xie and committed by GitHub

Fix mistral generate for long prompt / response (#27548)

* Fix mistral generate for long prompt / response

* Add unit test

* fix linter

* fix linter

* fix test

* add assisted generation test for mistral and load the model in 4 bit + fa2
parent 27b752bc
@@ -364,7 +364,7 @@ class MistralFlashAttention2(MistralAttention):
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
-                slicing_tokens = kv_seq_len - self.config.sliding_window
+                slicing_tokens = 1 - self.config.sliding_window
 
                 past_key = past_key_value[0]
                 past_value = past_key_value[1]
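Why this one-line change matters: in the flash-attention path, the key/value cache is meant to hold at most sliding_window - 1 past positions. The old offset, kv_seq_len - self.config.sliding_window, keeps growing with the total sequence length, so once generation runs past the window it slices beyond the end of the already-trimmed cache and breaks long prompts/responses; the new 1 - self.config.sliding_window is a fixed negative offset that always retains the last sliding_window - 1 cached positions. Below is a minimal sketch of the slicing difference on a toy tensor (illustrative only, with made-up sizes, not the model's actual cache bookkeeping):

import torch

sliding_window = 4               # toy value; Mistral-7B uses 4096
cache_len = sliding_window - 1   # the cache was already trimmed on a previous decoding step

# Toy cache of shape (batch, num_heads, cache_len, head_dim)
past_key = torch.zeros(1, 1, cache_len, 1)

# Total sequence length keeps growing as tokens are generated
kv_seq_len = 10

# Old offset: grows with kv_seq_len and slices past the end of the short cache
old_offset = kv_seq_len - sliding_window           # 6
print(past_key[:, :, old_offset:, :].shape[-2])    # 0 -> cache emptied, generation breaks

# Fixed offset: constant and negative, always keeps the last sliding_window - 1 positions
new_offset = 1 - sliding_window                    # -3
print(past_key[:, :, new_offset:, :].shape[-2])    # 3 == sliding_window - 1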
@@ -24,6 +24,7 @@ import pytest
 from transformers import AutoTokenizer, MistralConfig, is_torch_available
 from transformers.testing_utils import (
     backend_empty_cache,
+    require_bitsandbytes,
     require_flash_attn,
     require_torch,
     require_torch_gpu,
@@ -494,3 +495,32 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @require_bitsandbytes
+    @slow
+    @require_flash_attn
+    def test_model_7b_long_prompt(self):
+        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+        # An input with 4097 tokens that is above the size of the sliding window
+        input_ids = [1] + [306, 338] * 2048
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1",
+            device_map="auto",
+            load_in_4bit=True,
+            use_flash_attention_2=True,
+        )
+        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+        # Assisted generation
+        assistant_model = model
+        assistant_model.generation_config.num_assistant_tokens = 2
+        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+        del assistant_model
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()