Unverified Commit fae4d1c2 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2217 from aaugustin/test-parallelization

Support running tests in parallel
parents ac1b449c b8e924e1
version: 2 version: 2
jobs: jobs:
build_py3_torch_and_tf: run_tests_py3_torch_and_tf:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
- image: circleci/python:3.5 - image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge resource_class: xlarge
parallelism: 1 parallelism: 1
steps: steps:
...@@ -11,49 +13,67 @@ jobs: ...@@ -11,49 +13,67 @@ jobs:
- run: sudo pip install torch - run: sudo pip install torch
- run: sudo pip install tensorflow - run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off . - run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov - run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn - run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov - run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov - run: codecov
build_py3_torch: run_tests_py3_torch:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
- image: circleci/python:3.5 - image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge resource_class: xlarge
parallelism: 1 parallelism: 1
steps: steps:
- checkout - checkout
- run: sudo pip install torch - run: sudo pip install torch
- run: sudo pip install --progress-bar off . - run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov - run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn - run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov - run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: python -m pytest -sv ./examples/
- run: codecov - run: codecov
build_py3_tf: run_tests_py3_tf:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
- image: circleci/python:3.5 - image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge resource_class: xlarge
parallelism: 1 parallelism: 1
steps: steps:
- checkout - checkout
- run: sudo pip install tensorflow - run: sudo pip install tensorflow
- run: sudo pip install --progress-bar off . - run: sudo pip install --progress-bar off .
- run: sudo pip install pytest codecov pytest-cov - run: sudo pip install pytest codecov pytest-cov pytest-xdist
- run: sudo pip install tensorboardX scikit-learn - run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -sv ./transformers/tests/ --cov - run: python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
- run: codecov - run: codecov
build_py3_custom_tokenizers: run_tests_py3_custom_tokenizers:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
- image: circleci/python:3.5 - image: circleci/python:3.5
steps: steps:
- checkout - checkout
- run: sudo pip install --progress-bar off . - run: sudo pip install --progress-bar off .
- run: sudo pip install pytest - run: sudo pip install pytest pytest-xdist
- run: sudo pip install mecab-python3 - run: sudo pip install mecab-python3
- run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py - run: RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
run_examples_py3_torch:
working_directory: ~/transformers
docker:
- image: circleci/python:3.5
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- run: sudo pip install torch
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest pytest-xdist
- run: sudo pip install tensorboardX scikit-learn
- run: python -m pytest -n 8 --dist=loadfile -s -v ./examples/
deploy_doc: deploy_doc:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
...@@ -66,7 +86,7 @@ jobs: ...@@ -66,7 +86,7 @@ jobs:
- run: sudo pip install --progress-bar off -r docs/requirements.txt - run: sudo pip install --progress-bar off -r docs/requirements.txt
- run: sudo pip install --progress-bar off -r requirements.txt - run: sudo pip install --progress-bar off -r requirements.txt
- run: ./.circleci/deploy.sh - run: ./.circleci/deploy.sh
repository_consistency: check_repository_consistency:
working_directory: ~/transformers working_directory: ~/transformers
docker: docker:
- image: circleci/python:3.5 - image: circleci/python:3.5
...@@ -85,9 +105,10 @@ workflows: ...@@ -85,9 +105,10 @@ workflows:
version: 2 version: 2
build_and_test: build_and_test:
jobs: jobs:
- repository_consistency - check_repository_consistency
- build_py3_custom_tokenizers - run_examples_py3_torch
- build_py3_torch_and_tf - run_tests_py3_custom_tokenizers
- build_py3_torch - run_tests_py3_torch_and_tf
- build_py3_tf - run_tests_py3_torch
- run_tests_py3_tf
- deploy_doc: *workflow_filters - deploy_doc: *workflow_filters
...@@ -59,6 +59,7 @@ setup( ...@@ -59,6 +59,7 @@ setup(
"tests.*", "tests"]), "tests.*", "tests"]),
install_requires=['numpy', install_requires=['numpy',
'boto3', 'boto3',
'filelock',
'requests', 'requests',
'tqdm', 'tqdm',
'regex != 2019.12.17', 'regex != 2019.12.17',
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import XxxConfig, is_tf_available from transformers import XxxConfig, is_tf_available
...@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in ['xxx-base-uncased']: for model_name in ['xxx-base-uncased']:
model = TFXxxModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,13 +17,12 @@ from __future__ import division ...@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM, from transformers import (XxxConfig, XxxModel, XxxForMaskedLM,
...@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester): ...@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = XxxModel.from_pretrained(model_name, cache_dir=cache_dir) model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,10 +10,9 @@ import json ...@@ -10,10 +10,9 @@ import json
import logging import logging
import os import os
import six import six
import shutil
import tempfile import tempfile
import fnmatch import fnmatch
from functools import wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from io import open from io import open
...@@ -25,6 +24,8 @@ from tqdm.auto import tqdm ...@@ -25,6 +24,8 @@ from tqdm.auto import tqdm
from contextlib import contextmanager from contextlib import contextmanager
from . import __version__ from . import __version__
from filelock import FileLock
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
...@@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag ...@@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# If we don't have a connection (etag is None) and can't identify the file # If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one # try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None: if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') matching_files = [
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) file
for file in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
if not file.endswith('.json') and not file.endswith('.lock')
]
if matching_files: if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1]) cache_path = os.path.join(cache_dir, matching_files[-1])
if resume_download: # Prevent parallel downloads of the same file with a lock.
incomplete_path = cache_path + '.incomplete' lock_path = cache_path + '.lock'
@contextmanager with FileLock(lock_path):
def _resumable_file_manager():
with open(incomplete_path,'a+b') as f: if resume_download:
yield f incomplete_path = cache_path + '.incomplete'
os.remove(incomplete_path) @contextmanager
temp_file_manager = _resumable_file_manager def _resumable_file_manager():
if os.path.exists(incomplete_path): with open(incomplete_path,'a+b') as f:
resume_size = os.stat(incomplete_path).st_size yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else: else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
resume_size = 0 resume_size = 0
else:
temp_file_manager = tempfile.NamedTemporaryFile if etag is not None and (not os.path.exists(cache_path) or force_download):
resume_size = 0 # Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
if etag is not None and (not os.path.exists(cache_path) or force_download): with temp_file_manager() as temp_file:
# Download to temporary file, then copy to cache dir once finished. logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file: # GET file object
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) if url.startswith("s3://"):
if resume_download:
# GET file object logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
if url.startswith("s3://"): s3_get(url, temp_file, proxies=proxies)
if resume_download: else:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
s3_get(url, temp_file, proxies=proxies)
else: # we are copying the file before closing it, so flush to avoid truncation
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) temp_file.flush()
# we are copying the file before closing it, so flush to avoid truncation logger.info("storing %s in cache at %s", url, cache_path)
temp_file.flush() os.rename(temp_file.name, cache_path)
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
logger.info("copying %s to cache at %s", temp_file.name, cache_path) meta_path = cache_path + '.json'
with open(cache_path, 'wb') as cache_file: with open(meta_path, 'w') as meta_file:
shutil.copyfileobj(temp_file, cache_file) output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
logger.info("creating metadata file for %s", cache_path) output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta = {'url': url, 'etag': etag} meta_file.write(output_string)
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
if sys.version_info[0] == 2 and isinstance(output_string, str):
output_string = unicode(output_string, 'utf-8') # The beauty of python 2
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path return cache_path
...@@ -17,13 +17,12 @@ from __future__ import division ...@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM, from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM,
...@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): ...@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = AlbertModel.from_pretrained(model_name, cache_dir=cache_dir) model = AlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,13 +17,12 @@ from __future__ import division ...@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM, from transformers import (BertConfig, BertModel, BertForMaskedLM,
...@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester): ...@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = BertModel.from_pretrained(model_name, cache_dir=cache_dir) model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -18,7 +18,7 @@ from __future__ import print_function ...@@ -18,7 +18,7 @@ from __future__ import print_function
import copy import copy
import sys import sys
import os import os.path
import shutil import shutil
import tempfile import tempfile
import json import json
...@@ -30,7 +30,7 @@ import logging ...@@ -30,7 +30,7 @@ import logging
from transformers import is_torch_available from transformers import is_torch_available
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -218,21 +218,22 @@ class CommonTestCases: ...@@ -218,21 +218,22 @@ class CommonTestCases:
inputs = inputs_dict['input_ids'] # Let's keep only input_ids inputs = inputs_dict['input_ids'] # Let's keep only input_ids
try: try:
torch.jit.trace(model, inputs) traced_gpt2 = torch.jit.trace(model, inputs)
except RuntimeError: except RuntimeError:
self.fail("Couldn't trace module.") self.fail("Couldn't trace module.")
try: with TemporaryDirectory() as tmp_dir_name:
traced_gpt2 = torch.jit.trace(model, inputs) pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
torch.jit.save(traced_gpt2, "traced_model.pt")
except RuntimeError:
self.fail("Couldn't save module.")
try: try:
loaded_model = torch.jit.load("traced_model.pt") torch.jit.save(traced_gpt2, pt_file_name)
os.remove("traced_model.pt") except Exception:
except ValueError: self.fail("Couldn't save module.")
self.fail("Couldn't load module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -352,12 +353,11 @@ class CommonTestCases: ...@@ -352,12 +353,11 @@ class CommonTestCases:
heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0]} -1: [0]}
model.prune_heads(heads_to_prune) model.prune_heads(heads_to_prune)
directory = "pruned_model"
if not os.path.exists(directory): with TemporaryDirectory() as temp_dir_name:
os.makedirs(directory) model.save_pretrained(temp_dir_name)
model.save_pretrained(directory) model = model_class.from_pretrained(temp_dir_name)
model = model_class.from_pretrained(directory) model.to(torch_device)
model.to(torch_device)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
...@@ -366,7 +366,6 @@ class CommonTestCases: ...@@ -366,7 +366,6 @@ class CommonTestCases:
self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[1].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1) self.assertEqual(attentions[-1].shape[-3], self.model_tester.num_attention_heads - 1)
shutil.rmtree(directory)
def test_head_pruning_save_load_from_config_init(self): def test_head_pruning_save_load_from_config_init(self):
if not self.test_pruning: if not self.test_pruning:
...@@ -426,14 +425,10 @@ class CommonTestCases: ...@@ -426,14 +425,10 @@ class CommonTestCases:
self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads) self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)
directory = "pruned_model" with TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
if not os.path.exists(directory): model = model_class.from_pretrained(temp_dir_name)
os.makedirs(directory) model.to(torch_device)
model.save_pretrained(directory)
model = model_class.from_pretrained(directory)
model.to(torch_device)
shutil.rmtree(directory)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**inputs_dict)
...@@ -758,10 +753,8 @@ class CommonTestCases: ...@@ -758,10 +753,8 @@ class CommonTestCases:
[[], []]) [[], []])
def create_and_check_model_from_pretrained(self): def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]: for model_name in list(self.base_model_class.pretrained_model_archive_map.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir) model = self.base_model_class.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model) self.parent.assertIsNotNone(model)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
......
...@@ -16,7 +16,6 @@ from __future__ import division ...@@ -16,7 +16,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import pdb import pdb
from transformers import is_torch_available from transformers import is_torch_available
...@@ -27,7 +26,7 @@ if is_torch_available(): ...@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester): ...@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = CTRLModel.from_pretrained(model_name, cache_dir=cache_dir) model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -27,7 +27,7 @@ if is_torch_available(): ...@@ -27,7 +27,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): ...@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# @slow # @slow
# def test_model_from_pretrained(self): # def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir) # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model) # self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,7 +17,6 @@ from __future__ import division ...@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -27,7 +26,7 @@ if is_torch_available(): ...@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester): ...@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = GPT2Model.from_pretrained(model_name, cache_dir=cache_dir) model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -17,7 +17,6 @@ from __future__ import division ...@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -27,7 +26,7 @@ if is_torch_available(): ...@@ -27,7 +26,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): ...@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=cache_dir) model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -17,7 +17,6 @@ from __future__ import division ...@@ -17,7 +17,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
...@@ -29,7 +28,7 @@ if is_torch_available(): ...@@ -29,7 +28,7 @@ if is_torch_available():
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
@require_torch @require_torch
...@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester): ...@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = RobertaModel.from_pretrained(model_name, cache_dir=cache_dir) model = RobertaModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -17,13 +17,12 @@ from __future__ import division ...@@ -17,13 +17,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
from transformers import is_torch_available from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor) from .modeling_common_test import (CommonTestCases, ids_tensor, floats_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
from transformers import (T5Config, T5Model, T5WithLMHeadModel) from transformers import (T5Config, T5Model, T5WithLMHeadModel)
...@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester): ...@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = T5Model.from_pretrained(model_name, cache_dir=cache_dir) model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import AlbertConfig, is_tf_available from transformers import AlbertConfig, is_tf_available
...@@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/" for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: model = TFAlbertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
for model_name in ['albert-base-uncased']:
model = TFAlbertModel.from_pretrained(
model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
......
...@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']: for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
model = TFAutoModel.from_pretrained(model_name, force_download=True) model = TFAutoModel.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertModel) self.assertIsInstance(model, TFBertModel)
...@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']: for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
model = TFAutoModelWithLMHead.from_pretrained(model_name, force_download=True) model = TFAutoModelWithLMHead.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForMaskedLM) self.assertIsInstance(model, TFBertForMaskedLM)
...@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']: for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
model = TFAutoModelForSequenceClassification.from_pretrained(model_name, force_download=True) model = TFAutoModelForSequenceClassification.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForSequenceClassification) self.assertIsInstance(model, TFBertForSequenceClassification)
...@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']: for model_name in ['bert-base-uncased']:
config = AutoConfig.from_pretrained(model_name, force_download=True) config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config) self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig) self.assertIsInstance(config, BertConfig)
model = TFAutoModelForQuestionAnswering.from_pretrained(model_name, force_download=True) model = TFAutoModelForQuestionAnswering.from_pretrained(model_name)
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForQuestionAnswering) self.assertIsInstance(model, TFBertForQuestionAnswering)
def test_from_pretrained_identifier(self): def test_from_pretrained_identifier(self):
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER, force_download=True) model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM) self.assertIsInstance(model, TFBertForMaskedLM)
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import BertConfig, is_tf_available from transformers import BertConfig, is_tf_available
...@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ['bert-base-uncased']: for model_name in ['bert-base-uncased']:
model = TFBertModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import CTRLConfig, is_tf_available from transformers import CTRLConfig, is_tf_available
...@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFCTRLModel.from_pretrained(model_name, cache_dir=cache_dir) model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -20,7 +20,7 @@ import unittest ...@@ -20,7 +20,7 @@ import unittest
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import DistilBertConfig, is_tf_available from transformers import DistilBertConfig, is_tf_available
...@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
# @slow # @slow
# def test_model_from_pretrained(self): # def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: # for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir) # model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model) # self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,12 +17,11 @@ from __future__ import division ...@@ -17,12 +17,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import shutil
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import require_tf, slow from .utils import CACHE_DIR, require_tf, slow
from transformers import GPT2Config, is_tf_available from transformers import GPT2Config, is_tf_available
...@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
cache_dir = "/tmp/transformers_test/"
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFGPT2Model.from_pretrained(model_name, cache_dir=cache_dir) model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
shutil.rmtree(cache_dir)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment