Unverified Commit 847b47c0 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix XGLM OOM on CI (#24123)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent b8fe259f
......@@ -15,6 +15,7 @@
from __future__ import annotations
import gc
import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available
......@@ -190,6 +191,11 @@ class TFXGLMModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
@require_tf
class TFXGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@slow
def test_lm_generate_xglm(self, verify_outputs=True):
model = TFXGLMForCausalLM.from_pretrained("facebook/xglm-564M")
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import datetime
import gc
import math
import unittest
......@@ -349,6 +350,12 @@ class XGLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
@require_torch
class XGLMModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def _test_lm_generate_xglm_helper(
self,
gradient_checkpointing=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment