Unverified Commit 8a1a23ae authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix GPU OOM for `mistral.py::Mask4DTestHard` (#31212)



* build

* build

* build

* build

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent df5abae8
......@@ -734,15 +734,24 @@ class MistralIntegrationTest(unittest.TestCase):
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
model_name = "mistralai/Mistral-7B-v0.1"
_model = None
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@property
def model(self):
if self.__class__._model is None:
self.__class__._model = MistralForCausalLM.from_pretrained(
self.model_name, torch_dtype=self.model_dtype
).to(torch_device)
return self.__class__._model
def setUp(self):
model_name = "mistralai/Mistral-7B-v0.1"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
self.model_dtype = torch.float16
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
def get_test_data(self):
template = "my favorite {}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment