Unverified Commit 1d2b57b8 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix CTRL tests (#17508)



* fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 693720e5
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import gc
import unittest import unittest
from transformers import CTRLConfig, is_torch_available from transformers import CTRLConfig, is_torch_available
...@@ -181,6 +182,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -181,6 +182,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.model_tester = CTRLModelTester(self) self.model_tester = CTRLModelTester(self)
self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37) self.config_tester = ConfigTester(self, config_class=CTRLConfig, n_embd=37)
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
...@@ -201,6 +208,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -201,6 +208,12 @@ class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@require_torch @require_torch
class CTRLModelLanguageGenerationTest(unittest.TestCase): class CTRLModelLanguageGenerationTest(unittest.TestCase):
def tearDown(self):
super().tearDown()
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
@slow @slow
def test_lm_generate_ctrl(self): def test_lm_generate_ctrl(self):
model = CTRLLMHeadModel.from_pretrained("ctrl") model = CTRLLMHeadModel.from_pretrained("ctrl")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment