Unverified Commit 32adbb26, authored by Yih-Dar and committed by GitHub
Browse files

Fix PyTorch RAG tests GPU OOM (#16881)



* add torch.cuda.empty_cache in some PT RAG tests

* torch.cuda.empty_cache in tearDownModule()

* tearDown()

* add gc.collect()
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 3e47d19c
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import gc
import json import json
import os import os
import shutil import shutil
...@@ -195,6 +196,10 @@ class RagTestMixin: ...@@ -195,6 +196,10 @@ class RagTestMixin:
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_retriever(self, config): def get_retriever(self, config):
dataset = Dataset.from_dict( dataset = Dataset.from_dict(
{ {
...@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase): ...@@ -677,6 +682,12 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_tokenizers @require_tokenizers
@require_torch_non_multi_gpu @require_torch_non_multi_gpu
class RagModelIntegrationTests(unittest.TestCase): class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@cached_property @cached_property
def sequence_model(self): def sequence_model(self):
return ( return (
...@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase): ...@@ -1024,6 +1035,12 @@ class RagModelIntegrationTests(unittest.TestCase):
@require_torch @require_torch
@require_retrieval @require_retrieval
class RagModelSaveLoadTests(unittest.TestCase): class RagModelSaveLoadTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def get_rag_config(self): def get_rag_config(self):
question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base") question_encoder_config = AutoConfig.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn") generator_config = AutoConfig.from_pretrained("facebook/bart-large-cnn")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment