Unverified Commit 32adbb26 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix PyTorch RAG tests GPU OOM (#16881)



* add torch.cuda.empty_cache in some PT RAG tests

* torch.cuda.empty_cache in tearDownModule()

* tearDown()

* add gc.collect()
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 3e47d19c
......@@ -14,6 +14,7 @@
# limitations under the License.
import gc
import json
import os
import shutil
......@@ -195,6 +196,10 @@ class RagTestMixin:
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_retriever(self, config):
dataset = Dataset.from_dict(
{
......@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_tokenizers
@require_torch_non_multi_gpu
class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@cached_property
def sequence_model(self):
return (
......@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
@require_torch
@require_retrieval
class RagModelSaveLoadTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment