Commit b670c266 authored by Aymeric Augustin

Take advantage of the cache when running tests.

Caching models across test cases and across runs of the test suite makes
slow tests somewhat more bearable.

Use gettempdir() instead of /tmp in tests. This makes it easier to
change the location of the cache with semi-standard TMPDIR/TEMP/TMP
environment variables.

Fix #2222.
parent b67fa1a8
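For context, a minimal sketch (not part of this commit) of why gettempdir() makes the cache relocatable: tempfile.gettempdir() checks the TMPDIR, TEMP and TMP environment variables before falling back to /tmp, so exporting one of them before running the test suite moves the shared cache. The /var/tmp path below is only an illustrative override.

import os
import tempfile

# Hypothetical override; normally TMPDIR would be exported before starting the test run.
os.environ["TMPDIR"] = "/var/tmp"
tempfile.tempdir = None  # drop the cached default so gettempdir() re-reads the environment

print(tempfile.gettempdir())                                     # e.g. /var/tmp
print(os.path.join(tempfile.gettempdir(), "transformers_test"))  # where the test cache would land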
@@ -18,11 +18,10 @@ from __future__ import print_function
 import unittest
 import random
-import shutil
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow
 from transformers import TransfoXLConfig, is_tf_available
@@ -205,10 +204,8 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import print_function
 import unittest
-import shutil
 from transformers import is_tf_available
@@ -31,7 +30,7 @@ if is_tf_available():
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow
 @require_tf
@@ -252,10 +251,8 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
@@ -20,7 +20,6 @@ import os
 import unittest
 import json
 import random
-import shutil
 from transformers import XLNetConfig, is_tf_available
@@ -35,7 +34,7 @@ if is_tf_available():
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow
 @require_tf
@@ -319,10 +318,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
@@ -18,7 +18,6 @@ from __future__ import print_function
 import unittest
 import random
-import shutil
 from transformers import is_torch_available
@@ -29,7 +28,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device
 @require_torch
@@ -208,10 +207,8 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import print_function
 import unittest
-import shutil
 from transformers import is_torch_available
@@ -28,7 +27,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device
 @require_torch
@@ -318,10 +317,8 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
@@ -20,7 +20,6 @@ import os
 import unittest
 import json
 import random
-import shutil
 from transformers import is_torch_available
@@ -33,7 +32,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device
 @require_torch
@@ -385,10 +384,8 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
 import os
 import unittest
+import tempfile
 from distutils.util import strtobool
 from transformers.file_utils import _tf_available, _torch_available
+CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
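A hedged sketch of the behaviour the tests now rely on: CACHE_DIR points at a fixed directory that is never removed, so the first from_pretrained() call downloads the weights and any later call, in the same run or a later run of the suite, is served from disk. The model name and class below are illustrative and not taken from this commit.

import os
import tempfile

from transformers import BertModel

# Same expression as the CACHE_DIR definition shown above.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")

model = BertModel.from_pretrained("bert-base-uncased", cache_dir=CACHE_DIR)  # downloads on the first run
model = BertModel.from_pretrained("bert-base-uncased", cache_dir=CACHE_DIR)  # reuses the cached files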