Unverified Commit 8a1a23ae authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix GPU OOM for `mistral.py::Mask4DTestHard` (#31212)



* build

* build

* build

* build

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent df5abae8
...@@ -734,15 +734,24 @@ class MistralIntegrationTest(unittest.TestCase): ...@@ -734,15 +734,24 @@ class MistralIntegrationTest(unittest.TestCase):
@slow @slow
@require_torch_gpu @require_torch_gpu
class Mask4DTestHard(unittest.TestCase): class Mask4DTestHard(unittest.TestCase):
model_name = "mistralai/Mistral-7B-v0.1"
_model = None
def tearDown(self): def tearDown(self):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@property
def model(self):
if self.__class__._model is None:
self.__class__._model = MistralForCausalLM.from_pretrained(
self.model_name, torch_dtype=self.model_dtype
).to(torch_device)
return self.__class__._model
def setUp(self): def setUp(self):
model_name = "mistralai/Mistral-7B-v0.1" self.model_dtype = torch.float16
self.model_dtype = torch.float32 self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def get_test_data(self): def get_test_data(self):
template = "my favorite {}" template = "my favorite {}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment