Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
fae4d1c2
Unverified
Commit
fae4d1c2
authored
Dec 21, 2019
by
Thomas Wolf
Committed by
GitHub
Dec 21, 2019
Browse files
Merge pull request #2217 from aaugustin/test-parallelization
Support running tests in parallel
parents
ac1b449c
b8e924e1
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
157 additions
and
185 deletions
+157
-185
.circleci/config.yml
.circleci/config.yml
+39
-18
setup.py
setup.py
+1
-0
templates/adding_a_new_model/tests/modeling_tf_xxx_test.py
templates/adding_a_new_model/tests/modeling_tf_xxx_test.py
+2
-5
templates/adding_a_new_model/tests/modeling_xxx_test.py
templates/adding_a_new_model/tests/modeling_xxx_test.py
+2
-5
transformers/file_utils.py
transformers/file_utils.py
+53
-51
transformers/tests/modeling_albert_test.py
transformers/tests/modeling_albert_test.py
+2
-5
transformers/tests/modeling_bert_test.py
transformers/tests/modeling_bert_test.py
+2
-5
transformers/tests/modeling_common_test.py
transformers/tests/modeling_common_test.py
+24
-31
transformers/tests/modeling_ctrl_test.py
transformers/tests/modeling_ctrl_test.py
+2
-5
transformers/tests/modeling_distilbert_test.py
transformers/tests/modeling_distilbert_test.py
+2
-4
transformers/tests/modeling_gpt2_test.py
transformers/tests/modeling_gpt2_test.py
+2
-5
transformers/tests/modeling_openai_test.py
transformers/tests/modeling_openai_test.py
+2
-5
transformers/tests/modeling_roberta_test.py
transformers/tests/modeling_roberta_test.py
+2
-5
transformers/tests/modeling_t5_test.py
transformers/tests/modeling_t5_test.py
+2
-5
transformers/tests/modeling_tf_albert_test.py
transformers/tests/modeling_tf_albert_test.py
+3
-8
transformers/tests/modeling_tf_auto_test.py
transformers/tests/modeling_tf_auto_test.py
+9
-9
transformers/tests/modeling_tf_bert_test.py
transformers/tests/modeling_tf_bert_test.py
+2
-5
transformers/tests/modeling_tf_ctrl_test.py
transformers/tests/modeling_tf_ctrl_test.py
+2
-5
transformers/tests/modeling_tf_distilbert_test.py
transformers/tests/modeling_tf_distilbert_test.py
+2
-4
transformers/tests/modeling_tf_gpt2_test.py
transformers/tests/modeling_tf_gpt2_test.py
+2
-5
No files found.
.circleci/config.yml
View file @
fae4d1c2
version
:
2
version
:
2
jobs
:
jobs
:
build
_py3_torch_and_tf
:
run_tests
_py3_torch_and_tf
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
-
image
:
circleci/python:3.5
-
image
:
circleci/python:3.5
environment
:
OMP_NUM_THREADS
:
1
resource_class
:
xlarge
resource_class
:
xlarge
parallelism
:
1
parallelism
:
1
steps
:
steps
:
...
@@ -11,49 +13,67 @@ jobs:
...
@@ -11,49 +13,67 @@ jobs:
-
run
:
sudo pip install torch
-
run
:
sudo pip install torch
-
run
:
sudo pip install tensorflow
-
run
:
sudo pip install tensorflow
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest codecov pytest-cov
-
run
:
sudo pip install pytest codecov pytest-cov
pytest-xdist
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
python -m pytest -
s
v ./transformers/tests/ --cov
-
run
:
python -m pytest -
n 8 --dist=loadfile -s -
v ./transformers/tests/ --cov
-
run
:
codecov
-
run
:
codecov
build
_py3_torch
:
run_tests
_py3_torch
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
-
image
:
circleci/python:3.5
-
image
:
circleci/python:3.5
environment
:
OMP_NUM_THREADS
:
1
resource_class
:
xlarge
resource_class
:
xlarge
parallelism
:
1
parallelism
:
1
steps
:
steps
:
-
checkout
-
checkout
-
run
:
sudo pip install torch
-
run
:
sudo pip install torch
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest codecov pytest-cov
-
run
:
sudo pip install pytest codecov pytest-cov
pytest-xdist
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
python -m pytest -sv ./transformers/tests/ --cov
-
run
:
python -m pytest -n 8 --dist=loadfile -s -v ./transformers/tests/ --cov
-
run
:
python -m pytest -sv ./examples/
-
run
:
codecov
-
run
:
codecov
build
_py3_tf
:
run_tests
_py3_tf
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
-
image
:
circleci/python:3.5
-
image
:
circleci/python:3.5
environment
:
OMP_NUM_THREADS
:
1
resource_class
:
xlarge
resource_class
:
xlarge
parallelism
:
1
parallelism
:
1
steps
:
steps
:
-
checkout
-
checkout
-
run
:
sudo pip install tensorflow
-
run
:
sudo pip install tensorflow
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest codecov pytest-cov
-
run
:
sudo pip install pytest codecov pytest-cov
pytest-xdist
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
python -m pytest -
s
v ./transformers/tests/ --cov
-
run
:
python -m pytest -
n 8 --dist=loadfile -s -
v ./transformers/tests/ --cov
-
run
:
codecov
-
run
:
codecov
build
_py3_custom_tokenizers
:
run_tests
_py3_custom_tokenizers
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
-
image
:
circleci/python:3.5
-
image
:
circleci/python:3.5
steps
:
steps
:
-
checkout
-
checkout
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest
-
run
:
sudo pip install pytest
pytest-xdist
-
run
:
sudo pip install mecab-python3
-
run
:
sudo pip install mecab-python3
-
run
:
RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
-
run
:
RUN_CUSTOM_TOKENIZERS=1 python -m pytest -sv ./transformers/tests/tokenization_bert_japanese_test.py
run_examples_py3_torch
:
working_directory
:
~/transformers
docker
:
-
image
:
circleci/python:3.5
environment
:
OMP_NUM_THREADS
:
1
resource_class
:
xlarge
parallelism
:
1
steps
:
-
checkout
-
run
:
sudo pip install torch
-
run
:
sudo pip install --progress-bar off .
-
run
:
sudo pip install pytest pytest-xdist
-
run
:
sudo pip install tensorboardX scikit-learn
-
run
:
python -m pytest -n 8 --dist=loadfile -s -v ./examples/
deploy_doc
:
deploy_doc
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
...
@@ -66,7 +86,7 @@ jobs:
...
@@ -66,7 +86,7 @@ jobs:
-
run
:
sudo pip install --progress-bar off -r docs/requirements.txt
-
run
:
sudo pip install --progress-bar off -r docs/requirements.txt
-
run
:
sudo pip install --progress-bar off -r requirements.txt
-
run
:
sudo pip install --progress-bar off -r requirements.txt
-
run
:
./.circleci/deploy.sh
-
run
:
./.circleci/deploy.sh
repository_consistency
:
check_
repository_consistency
:
working_directory
:
~/transformers
working_directory
:
~/transformers
docker
:
docker
:
-
image
:
circleci/python:3.5
-
image
:
circleci/python:3.5
...
@@ -85,9 +105,10 @@ workflows:
...
@@ -85,9 +105,10 @@ workflows:
version
:
2
version
:
2
build_and_test
:
build_and_test
:
jobs
:
jobs
:
-
repository_consistency
-
check_repository_consistency
-
build_py3_custom_tokenizers
-
run_examples_py3_torch
-
build_py3_torch_and_tf
-
run_tests_py3_custom_tokenizers
-
build_py3_torch
-
run_tests_py3_torch_and_tf
-
build_py3_tf
-
run_tests_py3_torch
-
run_tests_py3_tf
-
deploy_doc
:
*workflow_filters
-
deploy_doc
:
*workflow_filters
setup.py
View file @
fae4d1c2
...
@@ -59,6 +59,7 @@ setup(
...
@@ -59,6 +59,7 @@ setup(
"tests.*"
,
"tests"
]),
"tests.*"
,
"tests"
]),
install_requires
=
[
'numpy'
,
install_requires
=
[
'numpy'
,
'boto3'
,
'boto3'
,
'filelock'
,
'requests'
,
'requests'
,
'tqdm'
,
'tqdm'
,
'regex != 2019.12.17'
,
'regex != 2019.12.17'
,
...
...
templates/adding_a_new_model/tests/modeling_tf_xxx_test.py
View file @
fae4d1c2
...
@@ -17,12 +17,11 @@ from __future__ import division
...
@@ -17,12 +17,11 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
sys
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
XxxConfig
,
is_tf_available
from
transformers
import
XxxConfig
,
is_tf_available
...
@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -245,10 +244,8 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
[
'xxx-base-uncased'
]:
for
model_name
in
[
'xxx-base-uncased'
]:
model
=
TFXxxModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
TFXxxModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
templates/adding_a_new_model/tests/modeling_xxx_test.py
View file @
fae4d1c2
...
@@ -17,13 +17,12 @@ from __future__ import division
...
@@ -17,13 +17,12 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
if
is_torch_available
():
from
transformers
import
(
XxxConfig
,
XxxModel
,
XxxForMaskedLM
,
from
transformers
import
(
XxxConfig
,
XxxModel
,
XxxForMaskedLM
,
...
@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
...
@@ -249,10 +248,8 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
XXX_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
XXX_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
XxxModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
XxxModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/file_utils.py
View file @
fae4d1c2
...
@@ -10,10 +10,9 @@ import json
...
@@ -10,10 +10,9 @@ import json
import
logging
import
logging
import
os
import
os
import
six
import
six
import
shutil
import
tempfile
import
tempfile
import
fnmatch
import
fnmatch
from
functools
import
wraps
from
functools
import
partial
,
wraps
from
hashlib
import
sha256
from
hashlib
import
sha256
from
io
import
open
from
io
import
open
...
@@ -25,6 +24,8 @@ from tqdm.auto import tqdm
...
@@ -25,6 +24,8 @@ from tqdm.auto import tqdm
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
.
import
__version__
from
.
import
__version__
from
filelock
import
FileLock
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
try
:
try
:
...
@@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
...
@@ -334,59 +335,60 @@ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None, etag
# If we don't have a connection (etag is None) and can't identify the file
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
# try to get the last downloaded one
if
not
os
.
path
.
exists
(
cache_path
)
and
etag
is
None
:
if
not
os
.
path
.
exists
(
cache_path
)
and
etag
is
None
:
matching_files
=
fnmatch
.
filter
(
os
.
listdir
(
cache_dir
),
filename
+
'.*'
)
matching_files
=
[
matching_files
=
list
(
filter
(
lambda
s
:
not
s
.
endswith
(
'.json'
),
matching_files
))
file
for
file
in
fnmatch
.
filter
(
os
.
listdir
(
cache_dir
),
filename
+
'.*'
)
if
not
file
.
endswith
(
'.json'
)
and
not
file
.
endswith
(
'.lock'
)
]
if
matching_files
:
if
matching_files
:
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
cache_path
=
os
.
path
.
join
(
cache_dir
,
matching_files
[
-
1
])
if
resume_download
:
# Prevent parallel downloads of the same file with a lock.
incomplete_path
=
cache_path
+
'.incomplete'
lock_path
=
cache_path
+
'.lock'
@
contextmanager
with
FileLock
(
lock_path
):
def
_resumable_file_manager
():
with
open
(
incomplete_path
,
'a+b'
)
as
f
:
if
resume_download
:
yield
f
incomplete_path
=
cache_path
+
'.incomplete'
os
.
remove
(
incomplete_path
)
@
contextmanager
temp_file_manager
=
_resumable_file_manager
def
_resumable_file_manager
():
if
os
.
path
.
exists
(
incomplete_path
):
with
open
(
incomplete_path
,
'a+b'
)
as
f
:
resume_size
=
os
.
stat
(
incomplete_path
).
st_size
yield
f
temp_file_manager
=
_resumable_file_manager
if
os
.
path
.
exists
(
incomplete_path
):
resume_size
=
os
.
stat
(
incomplete_path
).
st_size
else
:
resume_size
=
0
else
:
else
:
temp_file_manager
=
partial
(
tempfile
.
NamedTemporaryFile
,
dir
=
cache_dir
,
delete
=
False
)
resume_size
=
0
resume_size
=
0
else
:
temp_file_manager
=
tempfile
.
NamedTemporaryFile
if
etag
is
not
None
and
(
not
os
.
path
.
exists
(
cache_path
)
or
force_download
):
resume_size
=
0
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
if
etag
is
not
None
and
(
not
os
.
path
.
exists
(
cache_path
)
or
force_download
):
with
temp_file_manager
()
as
temp_file
:
# Download to temporary file, then copy to cache dir once finished.
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
# Otherwise you get corrupt cache entries if the download gets interrupted.
with
temp_file_manager
()
as
temp_file
:
# GET file object
logger
.
info
(
"%s not found in cache or force_download set to True, downloading to %s"
,
url
,
temp_file
.
name
)
if
url
.
startswith
(
"s3://"
):
if
resume_download
:
# GET file object
logger
.
warn
(
'Warning: resumable downloads are not implemented for "s3://" urls'
)
if
url
.
startswith
(
"s3://"
):
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
if
resume_download
:
else
:
logger
.
warn
(
'Warning: resumable downloads are not implemented for "s3://" urls'
)
http_get
(
url
,
temp_file
,
proxies
=
proxies
,
resume_size
=
resume_size
,
user_agent
=
user_agent
)
s3_get
(
url
,
temp_file
,
proxies
=
proxies
)
else
:
# we are copying the file before closing it, so flush to avoid truncation
http_get
(
url
,
temp_file
,
proxies
=
proxies
,
resume_size
=
resume_size
,
user_agent
=
user_agent
)
temp_file
.
flush
()
# we are copying the file before closing it, so flush to avoid truncation
logger
.
info
(
"storing %s in cache at %s"
,
url
,
cache_path
)
temp_file
.
flush
()
os
.
rename
(
temp_file
.
name
,
cache_path
)
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file
.
seek
(
0
)
logger
.
info
(
"creating metadata file for %s"
,
cache_path
)
meta
=
{
'url'
:
url
,
'etag'
:
etag
}
logger
.
info
(
"copying %s to cache at %s"
,
temp_file
.
name
,
cache_path
)
meta_path
=
cache_path
+
'.json'
with
open
(
cache_path
,
'wb'
)
as
cache_file
:
with
open
(
meta_path
,
'w'
)
as
meta_file
:
shutil
.
copyfileobj
(
temp_file
,
cache_file
)
output_string
=
json
.
dumps
(
meta
)
if
sys
.
version_info
[
0
]
==
2
and
isinstance
(
output_string
,
str
):
logger
.
info
(
"creating metadata file for %s"
,
cache_path
)
output_string
=
unicode
(
output_string
,
'utf-8'
)
# The beauty of python 2
meta
=
{
'url'
:
url
,
'etag'
:
etag
}
meta_file
.
write
(
output_string
)
meta_path
=
cache_path
+
'.json'
with
open
(
meta_path
,
'w'
)
as
meta_file
:
output_string
=
json
.
dumps
(
meta
)
if
sys
.
version_info
[
0
]
==
2
and
isinstance
(
output_string
,
str
):
output_string
=
unicode
(
output_string
,
'utf-8'
)
# The beauty of python 2
meta_file
.
write
(
output_string
)
logger
.
info
(
"removing temp file %s"
,
temp_file
.
name
)
return
cache_path
return
cache_path
transformers/tests/modeling_albert_test.py
View file @
fae4d1c2
...
@@ -17,13 +17,12 @@ from __future__ import division
...
@@ -17,13 +17,12 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
if
is_torch_available
():
from
transformers
import
(
AlbertConfig
,
AlbertModel
,
AlbertForMaskedLM
,
from
transformers
import
(
AlbertConfig
,
AlbertModel
,
AlbertForMaskedLM
,
...
@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
...
@@ -230,10 +229,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
AlbertModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
AlbertModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_bert_test.py
View file @
fae4d1c2
...
@@ -17,13 +17,12 @@ from __future__ import division
...
@@ -17,13 +17,12 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
,
floats_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
,
floats_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
if
is_torch_available
():
from
transformers
import
(
BertConfig
,
BertModel
,
BertForMaskedLM
,
from
transformers
import
(
BertConfig
,
BertModel
,
BertForMaskedLM
,
...
@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
...
@@ -360,10 +359,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
BertModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
BertModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_common_test.py
View file @
fae4d1c2
...
@@ -18,7 +18,7 @@ from __future__ import print_function
...
@@ -18,7 +18,7 @@ from __future__ import print_function
import
copy
import
copy
import
sys
import
sys
import
os
import
os
.path
import
shutil
import
shutil
import
tempfile
import
tempfile
import
json
import
json
...
@@ -30,7 +30,7 @@ import logging
...
@@ -30,7 +30,7 @@ import logging
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
...
@@ -218,21 +218,22 @@ class CommonTestCases:
...
@@ -218,21 +218,22 @@ class CommonTestCases:
inputs
=
inputs_dict
[
'input_ids'
]
# Let's keep only input_ids
inputs
=
inputs_dict
[
'input_ids'
]
# Let's keep only input_ids
try
:
try
:
torch
.
jit
.
trace
(
model
,
inputs
)
traced_gpt2
=
torch
.
jit
.
trace
(
model
,
inputs
)
except
RuntimeError
:
except
RuntimeError
:
self
.
fail
(
"Couldn't trace module."
)
self
.
fail
(
"Couldn't trace module."
)
try
:
with
TemporaryDirectory
()
as
tmp_dir_name
:
traced_gpt2
=
torch
.
jit
.
trace
(
model
,
inputs
)
pt_file_name
=
os
.
path
.
join
(
tmp_dir_name
,
"traced_model.pt"
)
torch
.
jit
.
save
(
traced_gpt2
,
"traced_model.pt"
)
except
RuntimeError
:
self
.
fail
(
"Couldn't save module."
)
try
:
try
:
loaded_model
=
torch
.
jit
.
load
(
"traced_model.pt"
)
torch
.
jit
.
save
(
traced_gpt2
,
pt_file_name
)
os
.
remove
(
"traced_model.pt"
)
except
Exception
:
except
ValueError
:
self
.
fail
(
"Couldn't save module."
)
self
.
fail
(
"Couldn't load module."
)
try
:
loaded_model
=
torch
.
jit
.
load
(
pt_file_name
)
except
Exception
:
self
.
fail
(
"Couldn't load module."
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
...
@@ -352,12 +353,11 @@ class CommonTestCases:
...
@@ -352,12 +353,11 @@ class CommonTestCases:
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
model_tester
.
num_attention_heads
)),
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
model_tester
.
num_attention_heads
)),
-
1
:
[
0
]}
-
1
:
[
0
]}
model
.
prune_heads
(
heads_to_prune
)
model
.
prune_heads
(
heads_to_prune
)
directory
=
"pruned_model"
if
not
os
.
path
.
exists
(
directory
):
with
TemporaryDirectory
()
as
temp_dir_name
:
os
.
makedirs
(
directory
)
model
.
save_pretrained
(
temp_dir_name
)
model
.
save_pretrained
(
directory
)
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
model
=
model_class
.
from_pretrained
(
directory
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
outputs
=
model
(
**
inputs_dict
)
outputs
=
model
(
**
inputs_dict
)
...
@@ -366,7 +366,6 @@ class CommonTestCases:
...
@@ -366,7 +366,6 @@ class CommonTestCases:
self
.
assertEqual
(
attentions
[
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
-
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
-
1
)
self
.
assertEqual
(
attentions
[
-
1
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
-
1
)
shutil
.
rmtree
(
directory
)
def
test_head_pruning_save_load_from_config_init
(
self
):
def
test_head_pruning_save_load_from_config_init
(
self
):
if
not
self
.
test_pruning
:
if
not
self
.
test_pruning
:
...
@@ -426,14 +425,10 @@ class CommonTestCases:
...
@@ -426,14 +425,10 @@ class CommonTestCases:
self
.
assertEqual
(
attentions
[
2
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
2
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
3
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
3
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
directory
=
"pruned_model"
with
TemporaryDirectory
()
as
temp_dir_name
:
model
.
save_pretrained
(
temp_dir_name
)
if
not
os
.
path
.
exists
(
directory
):
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
os
.
makedirs
(
directory
)
model
.
to
(
torch_device
)
model
.
save_pretrained
(
directory
)
model
=
model_class
.
from_pretrained
(
directory
)
model
.
to
(
torch_device
)
shutil
.
rmtree
(
directory
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
outputs
=
model
(
**
inputs_dict
)
outputs
=
model
(
**
inputs_dict
)
...
@@ -758,10 +753,8 @@ class CommonTestCases:
...
@@ -758,10 +753,8 @@ class CommonTestCases:
[[],
[]])
[[],
[]])
def
create_and_check_model_from_pretrained
(
self
):
def
create_and_check_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
self
.
base_model_class
.
pretrained_model_archive_map
.
keys
())[:
1
]:
for
model_name
in
list
(
self
.
base_model_class
.
pretrained_model_archive_map
.
keys
())[:
1
]:
model
=
self
.
base_model_class
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
self
.
base_model_class
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
parent
.
assertIsNotNone
(
model
)
self
.
parent
.
assertIsNotNone
(
model
)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
...
...
transformers/tests/modeling_ctrl_test.py
View file @
fae4d1c2
...
@@ -16,7 +16,6 @@ from __future__ import division
...
@@ -16,7 +16,6 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
pdb
import
pdb
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -27,7 +26,7 @@ if is_torch_available():
...
@@ -27,7 +26,7 @@ if is_torch_available():
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
@
require_torch
...
@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
...
@@ -205,10 +204,8 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
CTRLModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
CTRLModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_distilbert_test.py
View file @
fae4d1c2
...
@@ -27,7 +27,7 @@ if is_torch_available():
...
@@ -27,7 +27,7 @@ if is_torch_available():
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
@
require_torch
...
@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
...
@@ -235,10 +235,8 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
# @slow
# @slow
# def test_model_from_pretrained(self):
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
# self.assertIsNotNone(model)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_gpt2_test.py
View file @
fae4d1c2
...
@@ -17,7 +17,6 @@ from __future__ import division
...
@@ -17,7 +17,6 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -27,7 +26,7 @@ if is_torch_available():
...
@@ -27,7 +26,7 @@ if is_torch_available():
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
@
require_torch
...
@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
...
@@ -239,10 +238,8 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
GPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
GPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_openai_test.py
View file @
fae4d1c2
...
@@ -17,7 +17,6 @@ from __future__ import division
...
@@ -17,7 +17,6 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -27,7 +26,7 @@ if is_torch_available():
...
@@ -27,7 +26,7 @@ if is_torch_available():
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
@
require_torch
...
@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
...
@@ -207,10 +206,8 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
OpenAIGPTModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
OpenAIGPTModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_roberta_test.py
View file @
fae4d1c2
...
@@ -17,7 +17,6 @@ from __future__ import division
...
@@ -17,7 +17,6 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -29,7 +28,7 @@ if is_torch_available():
...
@@ -29,7 +28,7 @@ if is_torch_available():
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
@
require_torch
@
require_torch
...
@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
...
@@ -199,10 +198,8 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
RobertaModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
RobertaModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_t5_test.py
View file @
fae4d1c2
...
@@ -17,13 +17,12 @@ from __future__ import division
...
@@ -17,13 +17,12 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
,
floats_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
,
floats_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
from
.utils
import
CACHE_DIR
,
require_torch
,
slow
,
torch_device
if
is_torch_available
():
if
is_torch_available
():
from
transformers
import
(
T5Config
,
T5Model
,
T5WithLMHeadModel
)
from
transformers
import
(
T5Config
,
T5Model
,
T5WithLMHeadModel
)
...
@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
...
@@ -175,10 +174,8 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
T5_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
T5_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
T5Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
T5Model
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_tf_albert_test.py
View file @
fae4d1c2
...
@@ -17,12 +17,11 @@ from __future__ import division
...
@@ -17,12 +17,11 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
sys
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
AlbertConfig
,
is_tf_available
from
transformers
import
AlbertConfig
,
is_tf_available
...
@@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -217,12 +216,8 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
# for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model
=
TFAlbertModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
for
model_name
in
[
'albert-base-uncased'
]:
model
=
TFAlbertModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
...
...
transformers/tests/modeling_tf_auto_test.py
View file @
fae4d1c2
...
@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
...
@@ -46,11 +46,11 @@ class TFAutoModelTest(unittest.TestCase):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'bert-base-uncased'
]:
for
model_name
in
[
'bert-base-uncased'
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
,
force_download
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
TFAutoModel
.
from_pretrained
(
model_name
,
force_download
=
True
)
model
=
TFAutoModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
TFBertModel
)
self
.
assertIsInstance
(
model
,
TFBertModel
)
...
@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
...
@@ -59,11 +59,11 @@ class TFAutoModelTest(unittest.TestCase):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'bert-base-uncased'
]:
for
model_name
in
[
'bert-base-uncased'
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
,
force_download
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
TFAutoModelWithLMHead
.
from_pretrained
(
model_name
,
force_download
=
True
)
model
=
TFAutoModelWithLMHead
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
TFBertForMaskedLM
)
self
.
assertIsInstance
(
model
,
TFBertForMaskedLM
)
...
@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
...
@@ -72,11 +72,11 @@ class TFAutoModelTest(unittest.TestCase):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'bert-base-uncased'
]:
for
model_name
in
[
'bert-base-uncased'
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
,
force_download
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
TFAutoModelForSequenceClassification
.
from_pretrained
(
model_name
,
force_download
=
True
)
model
=
TFAutoModelForSequenceClassification
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
TFBertForSequenceClassification
)
self
.
assertIsInstance
(
model
,
TFBertForSequenceClassification
)
...
@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
...
@@ -85,17 +85,17 @@ class TFAutoModelTest(unittest.TestCase):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'bert-base-uncased'
]:
for
model_name
in
[
'bert-base-uncased'
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
,
force_download
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
TFAutoModelForQuestionAnswering
.
from_pretrained
(
model_name
,
force_download
=
True
)
model
=
TFAutoModelForQuestionAnswering
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
TFBertForQuestionAnswering
)
self
.
assertIsInstance
(
model
,
TFBertForQuestionAnswering
)
def
test_from_pretrained_identifier
(
self
):
def
test_from_pretrained_identifier
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
model
=
TFAutoModelWithLMHead
.
from_pretrained
(
SMALL_MODEL_IDENTIFIER
,
force_download
=
True
)
model
=
TFAutoModelWithLMHead
.
from_pretrained
(
SMALL_MODEL_IDENTIFIER
)
self
.
assertIsInstance
(
model
,
TFBertForMaskedLM
)
self
.
assertIsInstance
(
model
,
TFBertForMaskedLM
)
...
...
transformers/tests/modeling_tf_bert_test.py
View file @
fae4d1c2
...
@@ -17,12 +17,11 @@ from __future__ import division
...
@@ -17,12 +17,11 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
sys
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
BertConfig
,
is_tf_available
from
transformers
import
BertConfig
,
is_tf_available
...
@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -310,11 +309,9 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for
model_name
in
[
'bert-base-uncased'
]:
for
model_name
in
[
'bert-base-uncased'
]:
model
=
TFBertModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
TFBertModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_tf_ctrl_test.py
View file @
fae4d1c2
...
@@ -17,12 +17,11 @@ from __future__ import division
...
@@ -17,12 +17,11 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
sys
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
CTRLConfig
,
is_tf_available
from
transformers
import
CTRLConfig
,
is_tf_available
...
@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -189,10 +188,8 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFCTRLModel
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
TFCTRLModel
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_tf_distilbert_test.py
View file @
fae4d1c2
...
@@ -20,7 +20,7 @@ import unittest
...
@@ -20,7 +20,7 @@ import unittest
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
DistilBertConfig
,
is_tf_available
from
transformers
import
DistilBertConfig
,
is_tf_available
...
@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -211,10 +211,8 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
# @slow
# @slow
# def test_model_from_pretrained(self):
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/transformers_test/"
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# model = DistilBertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)
# self.assertIsNotNone(model)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/tests/modeling_tf_gpt2_test.py
View file @
fae4d1c2
...
@@ -17,12 +17,11 @@ from __future__ import division
...
@@ -17,12 +17,11 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
unittest
import
unittest
import
shutil
import
sys
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_tf
,
slow
from
.utils
import
CACHE_DIR
,
require_tf
,
slow
from
transformers
import
GPT2Config
,
is_tf_available
from
transformers
import
GPT2Config
,
is_tf_available
...
@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
...
@@ -220,10 +219,8 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFGPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
model
=
TFGPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
CACHE_DIR
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment