Unverified Commit fae4d1c2 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2217 from aaugustin/test-parallelization

Support running tests in parallel
parents ac1b449c b8e924e1
version: 2
jobs:
build_py3_torch_and_tf:
run_tests_py3_torch_and_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
......@@ -11,49 +13,67 @@ jobs:
- run: sudo pip install torch
- run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_torch:
run_tests_py3_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_tf:
run_tests_py3_tf:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov
- run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov
- run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov
build_py3_custom_tokenizers:
run_tests_py3_custom_tokenizers:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
steps:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest
- run: sudo pip install pytest pytest-xdist
- run: sudo pip install mecab-python3
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
run_examples_py3_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -n 8 --dist=loadfile -s -v ./examples/
deploy_doc:
working_directory: ~/transformers
docker:
......@@ -66,7 +86,7 @@ jobs:
- run: sudo pip install --progress-bar off -r docs/requirements.txt
- run: sudo pip install --progress-bar off -r requirements.txt
- run: ./.circleci/deploy.sh
repository_consistency:
check_repository_consistency:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
......@@ -85,9 +105,10 @@ workflows:
version: 2
build_and_test:
jobs:
- repository_consistency
- build_py3_custom_tokenizers
- build_py3_torch_and_tf
- build_py3_torch
- build_py3_tf
- check_repository_consistency
- run_examples_py3_torch
- run_tests_py3_custom_tokenizers
- run_tests_py3_torch_and_tf
- run_tests_py3_torch
- run_tests_py3_tf
- deploy_doc: *workflow_filters
......@@ -59,6 +59,7 @@ setup(
"tests.*", "tests"]),
install_requires=['numpy',
'boto3',
'filelock',
'requests',
'tqdm',
'regex != 2019.12.17',
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import XxxConfig, is_tf_available
......@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['xxx-base-uncased']:
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
......@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -10,10 +10,9 @@ import json
import logging
import os
import six
import shutil
import tempfile
import fnmatch
from functools import wraps
from functools import partial, wraps
from hashlib import sha256
from io import open
......@@ -25,6 +24,8 @@ from tqdm.auto import tqdm
from contextlib import contextmanager
from . import __version__
from filelock import FileLock
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try:
......@@ -334,25 +335,31 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
if not file.endswith('.json') and not file.endswith('.lock')
]
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock'
with FileLock(lock_path):
if resume_download:
incomplete_path = cache_path + '.incomplete'
@contextmanager
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f:
yield f
os.remove(incomplete_path)
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = tempfile.NamedTemporaryFile
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0
if etag is not None and (not os.path.exists(cache_path) or force_download):
......@@ -371,12 +378,9 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("storing %s in cache at %s", url, cache_path)
os.rename(temp_file.name, cache_path)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
......@@ -387,6 +391,4 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
......@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM,
......@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -18,7 +18,7 @@ from __future__ import print_function
import copy
import sys
import os
import os.path
import shutil
import tempfile
import json
......@@ -30,7 +30,7 @@ import logging
from transformers import is_torch_available
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
import torch
......@@ -218,20 +218,21 @@ class CommonTestCases:
inputs = inputs_dict['input_ids'] # Let's keep only input_ids
try:
torch.jit.trace(model, inputs)
traced_gpt2 = torch.jit.trace(model, inputs)
except RuntimeError:
self.fail("Couldn't trace module.")
with TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
traced_gpt2 = torch.jit.trace(model, inputs)
torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError:
torch.jit.save(traced_gpt2, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load("traced_model.pt")
os.remove("traced_model.pt")
except ValueError:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
......@@ -352,11 +353,10 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]}
model.prune_heads(heads_to_prune)
directory = "pruned_model"
if not os.path.exists(directory):
os.makedirs(directory)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)
with torch.no_grad():
......@@ -366,7 +366,6 @@ class CommonTestCases:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
shutil.rmtree(directory)
def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning:
......@@ -426,14 +425,10 @@ class CommonTestCases:
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
directory = "pruned_model"
if not os.path.exists(directory):
os.makedirs(directory)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
model = model_class.from_pretrained(temp_dir_name)
model.to(torch_device)
shutil.rmtree(directory)
with torch.no_grad():
outputs = model(**inputs_dict)
......@@ -758,10 +753,8 @@ class CommonTestCases:
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.parent.assertIsNotNone(model)
def prepare_config_and_inputs_for_common(self):
......
......@@ -16,7 +16,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import pdb
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -27,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
......@@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch
......@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device
from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
from transformers import (T5Config, T5Model, T5WithLMHeadModel)
......@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import AlbertConfig, is_tf_available
......@@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['albert-base-uncased']:
model = TFAlbertModel.from_pretrained(
model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
......
......@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModel.from_pretrained(model_name, force_download=True)
model = TFAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertModel)
......@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForMaskedLM)
......@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForSequenceClassification)
......@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True)
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import BertConfig, is_tf_available
......@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import CTRLConfig, is_tf_available
......@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -20,7 +20,7 @@ import unittest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import DistilBertConfig, is_tf_available
......@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
# @slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# self.assertIsNotNone(model)
if __name__ == "__main__":
......
......@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function
import unittest
import shutil
import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester
from .utils import require_tf, slow
from .utils import CACHE_DIR, require_tf, slow
from transformers import GPT2Config, is_tf_available
......@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow
def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment