Commit b670c266 authored by Aymeric Augustin

Take advantage of the cache when running tests.

Caching models across test cases and across runs of the test suite makes
slow tests somewhat more bearable.

Use gettempdir() instead of /tmp in tests. This makes it easier to
change the location of the cache with semi-standard TMPDIR/TEMP/TMP
environment variables.

Fix #2222.
parent b67fa1a8
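
The diff below replaces the hard-coded "/tmp/transformers_test/" cache (which every test deleted with shutil.rmtree) with a shared CACHE_DIR constant imported from the tests' utils module, and stops forcing re-downloads in the auto-model tests. The utils side of the change is not shown in this diff; based on the commit message above, a minimal sketch of what that constant could look like (the "transformers_test" directory name is carried over from the old hard-coded path; everything else is an assumption):

    # Hypothetical sketch of the CACHE_DIR constant that the tests below
    # import via "from .utils import CACHE_DIR".
    # tempfile.gettempdir() honors the TMPDIR/TEMP/TMP environment variables,
    # so the cache location can be moved without editing any test code.
    import os
    from tempfile import gettempdir

    CACHE_DIR = os.path.join(gettempdir(), "transformers_test")

Since no test removes this directory any more, downloaded models persist across test cases and across runs of the suite, and setting TMPDIR before running the tests relocates the whole cache.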
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import XxxConfig, is_tf_available
@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in ['xxx-base-uncased']:
-            model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,13 +17,12 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available

 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device

 if is_torch_available():
     from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,13 +17,12 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available

 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device

 if is_torch_available():
     from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,13 +17,12 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available

 from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device

 if is_torch_available():
     from transformers import (BertConfig, BertModel, BertForMaskedLM,
@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -30,7 +30,7 @@ import logging
 from transformers import is_torch_available

-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device

 if is_torch_available():
     import torch
@@ -753,10 +753,8 @@ class CommonTestCases:
             [[], []])

     def create_and_check_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
-            model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.parent.assertIsNotNone(model)

     def prepare_config_and_inputs_for_common(self):
...
@@ -16,7 +16,6 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import pdb

 from transformers import is_torch_available
@@ -27,7 +26,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device


 @require_torch
@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -27,7 +27,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device


 @require_torch
@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
     # @slow
     # def test_model_from_pretrained(self):
-    #     cache_dir = "/tmp/transformers_test/"
     #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-    #         model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
-    #         shutil.rmtree(cache_dir)
+    #         model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
     #         self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available
@@ -27,7 +26,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device


 @require_torch
@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available
@@ -27,7 +26,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device


 @require_torch
@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -17,7 +17,6 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available
@@ -29,7 +28,7 @@ if is_torch_available():
 from .modeling_common_test import (CommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device


 @require_torch
@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -17,13 +17,12 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from transformers import is_torch_available

 from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_torch, slow, torch_device
+from .utils import CACHE_DIR, require_torch, slow, torch_device

 if is_torch_available():
     from transformers import (T5Config, T5Model, T5WithLMHeadModel)
@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import AlbertConfig, is_tf_available
@@ -217,12 +216,9 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         # for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['albert-base-uncased']:
-            model = TFAlbertModel.from_pretrained(
-                model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
-            config = AutoConfig.from_pretrained(model_name, force_download=True)
+            config = AutoConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)

-            model = TFAutoModel.from_pretrained(model_name, force_download=True)
+            model = TFAutoModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
             self.assertIsInstance(model, TFBertModel)
@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
-            config = AutoConfig.from_pretrained(model_name, force_download=True)
+            config = AutoConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)

-            model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
+            model = TFAutoModelWithLMHead.from_pretrained(model_name)
             self.assertIsNotNone(model)
             self.assertIsInstance(model, TFBertForMaskedLM)
@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
-            config = AutoConfig.from_pretrained(model_name, force_download=True)
+            config = AutoConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)

-            model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
+            model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
             self.assertIsNotNone(model)
             self.assertIsInstance(model, TFBertForSequenceClassification)
@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
         logging.basicConfig(level=logging.INFO)
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
-            config = AutoConfig.from_pretrained(model_name, force_download=True)
+            config = AutoConfig.from_pretrained(model_name)
             self.assertIsNotNone(config)
             self.assertIsInstance(config, BertConfig)

-            model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
+            model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
             self.assertIsNotNone(model)
             self.assertIsInstance(model, TFBertForQuestionAnswering)

     def test_from_pretrained_identifier(self):
         logging.basicConfig(level=logging.INFO)
-        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
+        model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, TFBertForMaskedLM)
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import BertConfig, is_tf_available
@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         for model_name in ['bert-base-uncased']:
-            model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import CTRLConfig, is_tf_available
@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -20,7 +20,7 @@ import unittest
 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import DistilBertConfig, is_tf_available
@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
     # @slow
     # def test_model_from_pretrained(self):
-    #     cache_dir = "/tmp/transformers_test/"
     #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-    #         model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
-    #         shutil.rmtree(cache_dir)
+    #         model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
     #         self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import GPT2Config, is_tf_available
@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import OpenAIGPTConfig, is_tf_available
@@ -219,10 +218,8 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...
@@ -17,11 +17,10 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import RobertaConfig, is_tf_available
@@ -192,10 +191,8 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TFRobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFRobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
...
@@ -17,12 +17,11 @@ from __future__ import division
 from __future__ import print_function

 import unittest
-import shutil
 import sys

 from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
 from .configuration_common_test import ConfigTester
-from .utils import require_tf, slow
+from .utils import CACHE_DIR, require_tf, slow

 from transformers import T5Config, is_tf_available
@@ -162,10 +161,8 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
     @slow
     def test_model_from_pretrained(self):
-        cache_dir = "/tmp/transformers_test/"
         for model_name in ['t5-small']:
-            model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
+            model = TFT5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)

 if __name__ == "__main__":
...