Unverified Commit 2e2088f2 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Avoid `GPT-2` daily CI job OOM (in TF tests) (#24106)



* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 9322c244
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import datetime import datetime
import gc
import math import math
import unittest import unittest
...@@ -500,6 +501,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -500,6 +501,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
self.model_tester = GPT2ModelTester(self) self.model_tester = GPT2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37) self.config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
...@@ -683,6 +690,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -683,6 +690,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_torch @require_torch
class GPT2ModelLanguageGenerationTest(unittest.TestCase): class GPT2ModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _test_lm_generate_gpt2_helper( def _test_lm_generate_gpt2_helper(
self, self,
gradient_checkpointing=False, gradient_checkpointing=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment