Unverified Commit df848acc authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix `test_compile_static_cache` (#30991)



* fix

* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 70c87138
...@@ -729,11 +729,8 @@ class LlamaIntegrationTest(unittest.TestCase): ...@@ -729,11 +729,8 @@ class LlamaIntegrationTest(unittest.TestCase):
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
], ],
7: [ 7: [
"Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory " "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe theory of relativ",
"goes that nothing travels faster than light, but the faster you go, the slower everything else will " "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
"be.\nThe theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, "
"and even on a good old fashioned cheeseburger. I love it on everything. I love it so",
], ],
9: [ 9: [
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial" "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial"
......
...@@ -27,6 +27,7 @@ from transformers.testing_utils import ( ...@@ -27,6 +27,7 @@ from transformers.testing_utils import (
is_flaky, is_flaky,
require_bitsandbytes, require_bitsandbytes,
require_flash_attn, require_flash_attn,
require_read_token,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
require_torch_sdpa, require_torch_sdpa,
...@@ -658,12 +659,16 @@ class MistralIntegrationTest(unittest.TestCase): ...@@ -658,12 +659,16 @@ class MistralIntegrationTest(unittest.TestCase):
gc.collect() gc.collect()
@slow @slow
@require_read_token
def test_compile_static_cache(self): def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943 # work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"): if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.") self.skipTest("This test requires torch >= 2.3 to run.")
if self.cuda_compute_capability_major_version == 7:
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU.")
NUM_TOKENS_TO_GENERATE = 40 NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = { EXPECTED_TEXT_COMPLETION = {
8: [ 8: [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment