Unverified Commit 847b47c0 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix XGLM OOM on CI (#24123)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent b8fe259f
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import annotations from __future__ import annotations
import gc
import unittest import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
...@@ -190,6 +191,11 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase ...@@ -190,6 +191,11 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@require_tf @require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase): class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@slow @slow
def test_lm_generate_xglm(self, verify_outputs=True): def test_lm_generate_xglm(self, verify_outputs=True):
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M") model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import gc
import math import math
import unittest import unittest
...@@ -349,6 +350,12 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin ...@@ -349,6 +350,12 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_torch @require_torch
class XGLMModelLanguageGenerationTest(unittest.TestCase): class XGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _test_lm_generate_xglm_helper( def _test_lm_generate_xglm_helper(
self, self,
gradient_checkpointing=False, gradient_checkpointing=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment