Unverified Commit 999981da authored by Joao Gante, committed by GitHub

Tests: remove cuda versions when the result is the same 🧹🧹 (#31955)

remove cuda versions when the result is the same
parent 693cb828
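For context: these integration tests kept one expected generation per GPU family, keyed by the major CUDA compute capability of the runner (7 for T4, 8 for A100/A10, 9 for MI300), and indexed the dict with `self.cuda_compute_capability_major_version` at assert time. When every key maps to the same value, the dict adds noise without adding coverage, so this commit collapses it into a single expected value. Below is a minimal, hedged sketch of the before/after pattern; the `major_compute_capability()` helper and the abbreviated expected string are hypothetical stand-ins (the real test classes derive the attribute themselves, e.g. via `torch.cuda.get_device_capability()`), not the library's API.

import torch


def major_compute_capability() -> int:
    # Hypothetical helper: major CUDA compute capability of device 0
    # (7 = T4, 8 = A100/A10). Falls back to 8 on machines without a GPU so
    # the sketch stays runnable anywhere.
    if torch.cuda.is_available():
        return torch.cuda.get_device_capability(0)[0]
    return 8


# Before: one expected completion per GPU family, keyed by compute capability,
# even though every entry is identical.
EXPECTED_TEXT_COMPLETION_BEFORE = {
    7: "My favourite condiment is 100% ketchup.",
    8: "My favourite condiment is 100% ketchup.",
    9: "My favourite condiment is 100% ketchup.",
}

# After: a single expected value, since the result is the same on all tested GPUs.
EXPECTED_TEXT_COMPLETION_AFTER = "My favourite condiment is 100% ketchup."

# The per-device lookup and the flat value agree, so the simpler form is kept.
key = major_compute_capability()
assert EXPECTED_TEXT_COMPLETION_BEFORE.get(key, EXPECTED_TEXT_COMPLETION_AFTER) == EXPECTED_TEXT_COMPLETION_AFTER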
@@ -566,24 +566,10 @@ class GemmaIntegrationTest(unittest.TestCase):
     def test_model_2b_bf16(self):
         model_id = "google/gemma-2b"
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
-            ],
-            8: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-            9: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-        }
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
+        ]

         model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
             torch_device
@@ -595,30 +581,16 @@ class GemmaIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

-        self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_text, EXPECTED_TEXTS)

     @require_read_token
     def test_model_2b_eager(self):
         model_id = "google/gemma-2b"
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
-            ],
-            8: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-            9: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-        }
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
+        ]

         model = AutoModelForCausalLM.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
@@ -631,31 +603,17 @@ class GemmaIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

-        self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_text, EXPECTED_TEXTS)

     @require_torch_sdpa
     @require_read_token
     def test_model_2b_sdpa(self):
         model_id = "google/gemma-2b"
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXTS = {
-            7: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
-            ],
-            8: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-            9: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music",
-                "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
-            ],
-        }
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
+        ]

         model = AutoModelForCausalLM.from_pretrained(
             model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
@@ -668,7 +626,7 @@ class GemmaIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

-        self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_text, EXPECTED_TEXTS)

     @pytest.mark.flash_attn_test
     @require_flash_attn
@@ -734,7 +692,7 @@ class GemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_model_7b_fp16(self):
         if self.cuda_compute_capability_major_version == 7:
-            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
+            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
         EXPECTED_TEXTS = [
@@ -757,7 +715,7 @@ class GemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_model_7b_bf16(self):
         if self.cuda_compute_capability_major_version == 7:
-            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
+            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
@@ -795,7 +753,7 @@ class GemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_model_7b_fp16_static_cache(self):
         if self.cuda_compute_capability_major_version == 7:
-            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
+            self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")

         model_id = "google/gemma-7b"
         EXPECTED_TEXTS = [
@@ -821,16 +779,10 @@ class GemmaIntegrationTest(unittest.TestCase):
     @require_read_token
     def test_model_7b_4bit(self):
         model_id = "google/gemma-7b"
-        EXPECTED_TEXTS = {
-            7: [
-                "Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
-                "Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
-            ],
-            8: [
-                "Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
-                "Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
-            ],
-        }
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
+            "Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
+        ]

         model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
@@ -839,7 +791,7 @@ class GemmaIntegrationTest(unittest.TestCase):
         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

-        self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
+        self.assertEqual(output_text, EXPECTED_TEXTS)

     @slow
     @require_torch_gpu
@@ -851,27 +803,10 @@ class GemmaIntegrationTest(unittest.TestCase):
             self.skipTest(reason="This test requires torch >= 2.3 to run.")

         NUM_TOKENS_TO_GENERATE = 40
-        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
-        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
-        #
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXT_COMPLETION = {
-            8: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
-                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
-            ],
-            7: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
-                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
-            ],
-            9: [
-                "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
-                "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
-            ],
-        }
+        EXPECTED_TEXT_COMPLETION = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
+            "Hi today\nI have a problem with my 2007 1.9 tdi 105bhp.\nI have a problem with the engine management light on.\nI have checked the",
+        ]

         prompts = ["Hello I am doing", "Hi today"]
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
@@ -888,7 +823,7 @@ class GemmaIntegrationTest(unittest.TestCase):
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

         # Static Cache + compile
         model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
@@ -896,7 +831,7 @@ class GemmaIntegrationTest(unittest.TestCase):
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

     def test_model_2b_bf16_dola(self):
         model_id = "google/gemma-2b"
...
@@ -738,32 +738,13 @@ class LlamaIntegrationTest(unittest.TestCase):
         NUM_TOKENS_TO_GENERATE = 40
         # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
         # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
-        #
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXT_COMPLETION = {
-            8: [
-                "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
-                "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
-                "theory of relativ",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
-                "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-            7: [
-                "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory of relativ",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-            9: [
-                "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial"
-                " reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
-                "theory of relativ",
-                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs,"
-                " my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
-            ],
-        }
-        expected_text_completion_idx = 8
+        EXPECTED_TEXT_COMPLETION = [
+            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
+            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
+            "theory of relativ",
+            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
+            "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
+        ]

         prompts = [
             "Simply put, the theory of relativity states that ",
@@ -778,16 +759,14 @@ class LlamaIntegrationTest(unittest.TestCase):
         # Dynamic Cache
         generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
         dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(
-            EXPECTED_TEXT_COMPLETION[expected_text_completion_idx], dynamic_text
-        )  # Both GPU architectures have the same output
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

         # Static Cache
         generated_ids = model.generate(
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

         # Static Cache + compile
         model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
@@ -795,7 +774,7 @@ class LlamaIntegrationTest(unittest.TestCase):
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

     @slow
...
@@ -538,10 +538,7 @@ class MistralIntegrationTest(unittest.TestCase):
     @slow
     @require_bitsandbytes
     def test_model_7b_generation(self):
-        EXPECTED_TEXT_COMPLETION = {
-            7: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
-            8: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
-        }
+        EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,"

         prompt = "My favourite condiment is "
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
@@ -553,7 +550,7 @@ class MistralIntegrationTest(unittest.TestCase):
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

     @slow
     def test_model_7b_dola_generation(self):
@@ -641,15 +638,7 @@ class MistralIntegrationTest(unittest.TestCase):
     @slow
     def test_speculative_generation(self):
-        # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
-        #
-        # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
-        # considering differences in hardware processing and potential deviations in generated text.
-        EXPECTED_TEXT_COMPLETION = {
-            7: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
-            8: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
-            9: "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big",
-        }
+        EXPECTED_TEXT_COMPLETION = "My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"

         prompt = "My favourite condiment is "
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
         model = MistralForCausalLM.from_pretrained(
@@ -663,7 +652,7 @@ class MistralIntegrationTest(unittest.TestCase):
             input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
         )
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

     @slow
     @require_read_token
@@ -677,16 +666,10 @@ class MistralIntegrationTest(unittest.TestCase):
             self.skipTest(reason="This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")

         NUM_TOKENS_TO_GENERATE = 40
-        EXPECTED_TEXT_COMPLETION = {
-            8: [
-                "My favourite condiment is 100% ketchup. I love it on everything. "
-                "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
-            ],
-            7: [
-                "My favourite condiment is 100% ketchup. I love it on everything. "
-                "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
-            ],
-        }
+        EXPECTED_TEXT_COMPLETION = [
+            "My favourite condiment is 100% ketchup. I love it on everything. "
+            "I’m not a big fan of mustard, mayo, or relish. I’m not a fan of pickles"
+        ]

         prompts = ["My favourite condiment is "]
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
@@ -699,21 +682,21 @@ class MistralIntegrationTest(unittest.TestCase):
         # Dynamic Cache
         generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
         dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

         # Static Cache
         generated_ids = model.generate(
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

         # Sliding Window Cache
         generated_ids = model.generate(
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
         )
         static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

         # Static Cache + compile
         forward_function = model.forward
@@ -722,7 +705,7 @@ class MistralIntegrationTest(unittest.TestCase):
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
         )
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

         # Sliding Window Cache + compile
         torch._dynamo.reset()
@@ -731,7 +714,7 @@ class MistralIntegrationTest(unittest.TestCase):
             **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
         )
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

     @slow
...