"examples/vscode:/vscode.git/clone" did not exist on "f1a4e06f1fe2baaf85799db2b0316991ee1a2405"
Commit b670c266 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Take advantage of the cache when running tests.

Caching models across test cases and across runs of the test suite makes
slow tests somewhat more bearable.

Use gettempdir() instead of /tmp in tests. This makes it easier to
change the location of the cache with semi-standard TMPDIR/TEMP/TMP
environment variables.

Fix #2222.
parent b67fa1a8
...@@ -18,11 +18,10 @@ from __future__ import print_function ...@@ -18,11 +18,10 @@ from __future__ import print_function
import unittest import unittest
import random import random
import shutil
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import TransfoXLConfig, is_tf_available from transformers import TransfoXLConfig, is_tf_available
...@@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -17,7 +17,6 @@ from __future__ import division ...@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_tf_available from transformers import is_tf_available
...@@ -31,7 +30,7 @@ if is_tf_available(): ...@@ -31,7 +30,7 @@ if is_tf_available():
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
@require_tf @require_tf
...@@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -20,7 +20,6 @@ import os ...@@ -20,7 +20,6 @@ import os
import unittest import unittest
import json import json
import random import random
import shutil
from transformers import XLNetConfig, is_tf_available from transformers import XLNetConfig, is_tf_available
...@@ -35,7 +34,7 @@ if is_tf_available(): ...@@ -35,7 +34,7 @@ if is_tf_available():
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
@require_tf @require_tf
...@@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import unittest import unittest
import random import random
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -29,7 +28,7 @@ if is_torch_available(): ...@@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester): ...@@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir) model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -17,7 +17,6 @@ from __future__ import division ...@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -28,7 +27,7 @@ if is_torch_available(): ...@@ -28,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir) model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -20,7 +20,6 @@ import os ...@@ -20,7 +20,6 @@ import os
import unittest import unittest
import json import json
import random import random
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -33,7 +32,7 @@ if is_torch_available(): ...@@ -33,7 +32,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): ...@@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir) model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
import os import os
import unittest import unittest
import tempfile
from distutils.util import strtobool from distutils.util import strtobool
from transformers.file_utils import _tf_available, _torch_available from transformers.file_utils import _tf_available, _torch_available
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment