Unverified Commit 707023d1 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix TF Rag OOM issue (#24122)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f2b91835
from __future__ import annotations
import gc
import json
import os
import shutil
......@@ -550,6 +551,11 @@ class TFRagDPRBartTest(TFRagTestMixin, unittest.TestCase):
@require_sentencepiece
@require_tokenizers
class TFRagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
@cached_property
def token_model(self):
return TFRagTokenForGeneration.from_pretrained_question_encoder_generator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment