Commit b670c266 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Take advantage of the cache when running tests.

Caching models across test cases and across runs of the test suite makes
slow tests somewhat more bearable.

Use gettempdir() instead of /tmp in tests. This makes it easier to
change the location of the cache with semi-standard TMPDIR/TEMP/TMP
environment variables.

Fix #2222.
parent b67fa1a8
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import XxxConfig, is_tf_available
......@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['xxx-base-uncased']:
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
......@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
......@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
......@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -30,7 +30,7 @@ import logging
from transformers import is_torch_available
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
import torch
......@@ -753,10 +753,8 @@ class CommonTestCases:
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.parent.assertIsNotNone(model)
def prepare_config_and_inputs_for_common(self):
......
......@@ -16,7 +16,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pdb
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -27,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (T5Config, T5Model, T5WithLMHeadModel)
......@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import AlbertConfig, is_tf_available
......@@ -217,12 +216,9 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['albert-base-uncased']:
model = TFAlbertModel.from_pretrained(
model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModel.from_pretrained(model_name, force_download=True)
model = TFAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertModel)
......@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForMaskedLM)
......@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForSequenceClassification)
......@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import BertConfig, is_tf_available
......@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import CTRLConfig, is_tf_available
......@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -20,7 +20,7 @@ import unittest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import DistilBertConfig, is_tf_available
......@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import GPT2Config, is_tf_available
......@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import OpenAIGPTConfig, is_tf_available
......@@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,11 +17,10 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import RobertaConfig, is_tf_available
......@@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import T5Config, is_tf_available
......@@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['t5-small']:
model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment