Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
798b3b38
Commit
798b3b38
authored
Dec 22, 2019
by
Aymeric Augustin
Browse files
Remove sys.version_info[0] == 2 or 3.
parent
8af25b16
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
41 additions
and
170 deletions
+41
-170
examples/contrib/run_swag.py
examples/contrib/run_swag.py
+1
-7
examples/utils_multiple_choice.py
examples/utils_multiple_choice.py
+1
-8
src/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
...s/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
+1
-6
src/transformers/data/processors/utils.py
src/transformers/data/processors/utils.py
+1
-8
src/transformers/file_utils.py
src/transformers/file_utils.py
+5
-12
src/transformers/modeling_bert.py
src/transformers/modeling_bert.py
+2
-7
src/transformers/modeling_tf_albert.py
src/transformers/modeling_tf_albert.py
+2
-7
src/transformers/modeling_tf_bert.py
src/transformers/modeling_tf_bert.py
+2
-7
src/transformers/modeling_tf_xlnet.py
src/transformers/modeling_tf_xlnet.py
+1
-4
src/transformers/modeling_xlnet.py
src/transformers/modeling_xlnet.py
+1
-4
src/transformers/tokenization_gpt2.py
src/transformers/tokenization_gpt2.py
+4
-11
src/transformers/tokenization_transfo_xl.py
src/transformers/tokenization_transfo_xl.py
+1
-6
tests/test_configuration_common.py
tests/test_configuration_common.py
+3
-4
tests/test_model_card.py
tests/test_model_card.py
+3
-4
tests/test_modeling_common.py
tests/test_modeling_common.py
+4
-23
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+3
-22
tests/test_optimization.py
tests/test_optimization.py
+2
-2
tests/test_tokenization_common.py
tests/test_tokenization_common.py
+4
-28
No files found.
examples/contrib/run_swag.py
View file @
798b3b38
...
...
@@ -24,7 +24,6 @@ import glob
import
logging
import
os
import
random
import
sys
import
numpy
as
np
import
torch
...
...
@@ -104,12 +103,7 @@ class InputFeatures(object):
def
read_swag_examples
(
input_file
,
is_training
=
True
):
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
reader
=
csv
.
reader
(
f
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
"utf-8"
)
for
cell
in
line
)
# noqa: F821
lines
.
append
(
line
)
lines
=
list
(
csv
.
reader
(
f
))
if
is_training
and
lines
[
0
][
-
1
]
!=
"label"
:
raise
ValueError
(
"For training, the input file must contain a label column."
)
...
...
examples/utils_multiple_choice.py
View file @
798b3b38
...
...
@@ -21,7 +21,6 @@ import glob
import
json
import
logging
import
os
import
sys
from
io
import
open
from
typing
import
List
...
...
@@ -179,13 +178,7 @@ class SwagProcessor(DataProcessor):
def
_read_csv
(
self
,
input_file
):
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
reader
=
csv
.
reader
(
f
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
"utf-8"
)
for
cell
in
line
)
# noqa: F821
lines
.
append
(
line
)
return
lines
return
list
(
csv
.
reader
(
f
))
def
_create_examples
(
self
,
lines
:
List
[
List
[
str
]],
type
:
str
):
"""Creates examples for the training and dev sets."""
...
...
src/transformers/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
View file @
798b3b38
...
...
@@ -18,6 +18,7 @@
import
argparse
import
logging
import
os
import
pickle
import
sys
from
io
import
open
...
...
@@ -34,12 +35,6 @@ from transformers import (
from
transformers.tokenization_transfo_xl
import
CORPUS_NAME
,
VOCAB_FILES_NAMES
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# We do this to be able to load python 2 datasets pickles
...
...
src/transformers/data/processors/utils.py
View file @
798b3b38
...
...
@@ -18,7 +18,6 @@ import copy
import
csv
import
json
import
logging
import
sys
from
...file_utils
import
is_tf_available
,
is_torch_available
...
...
@@ -98,13 +97,7 @@ class DataProcessor(object):
def
_read_tsv
(
cls
,
input_file
,
quotechar
=
None
):
"""Reads a tab separated value file."""
with
open
(
input_file
,
"r"
,
encoding
=
"utf-8-sig"
)
as
f
:
reader
=
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
"utf-8"
)
for
cell
in
line
)
# noqa: F821
lines
.
append
(
line
)
return
lines
return
list
(
csv
.
reader
(
f
,
delimiter
=
"
\t
"
,
quotechar
=
quotechar
))
class
SingleSentenceClassificationProcessor
(
DataProcessor
):
...
...
src/transformers/file_utils.py
View file @
798b3b38
...
...
@@ -166,7 +166,7 @@ def filename_to_url(filename, cache_dir=None):
"""
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
cache_path
=
os
.
path
.
join
(
cache_dir
,
filename
)
...
...
@@ -201,9 +201,9 @@ def cached_path(
"""
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
url_or_filename
,
Path
):
if
isinstance
(
url_or_filename
,
Path
):
url_or_filename
=
str
(
url_or_filename
)
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
if
is_remote_url
(
url_or_filename
):
...
...
@@ -314,9 +314,7 @@ def get_from_cache(
"""
if
cache_dir
is
None
:
cache_dir
=
TRANSFORMERS_CACHE
if
sys
.
version_info
[
0
]
==
3
and
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
if
sys
.
version_info
[
0
]
==
2
and
not
isinstance
(
cache_dir
,
str
):
if
isinstance
(
cache_dir
,
Path
):
cache_dir
=
str
(
cache_dir
)
if
not
os
.
path
.
exists
(
cache_dir
):
...
...
@@ -335,8 +333,6 @@ def get_from_cache(
except
(
EnvironmentError
,
requests
.
exceptions
.
Timeout
):
etag
=
None
if
sys
.
version_info
[
0
]
==
2
and
etag
is
not
None
:
etag
=
etag
.
decode
(
"utf-8"
)
filename
=
url_to_filename
(
url
,
etag
)
# get cache path to put the file
...
...
@@ -400,9 +396,6 @@ def get_from_cache(
meta
=
{
"url"
:
url
,
"etag"
:
etag
}
meta_path
=
cache_path
+
".json"
with
open
(
meta_path
,
"w"
)
as
meta_file
:
output_string
=
json
.
dumps
(
meta
)
if
sys
.
version_info
[
0
]
==
2
and
isinstance
(
output_string
,
str
):
output_string
=
unicode
(
output_string
,
"utf-8"
)
# noqa: F821
meta_file
.
write
(
output_string
)
json
.
dump
(
meta
,
meta_file
)
return
cache_path
src/transformers/modeling_bert.py
View file @
798b3b38
...
...
@@ -19,7 +19,6 @@
import
logging
import
math
import
os
import
sys
import
torch
from
torch
import
nn
...
...
@@ -338,9 +337,7 @@ class BertIntermediate(nn.Module):
def
__init__
(
self
,
config
):
super
(
BertIntermediate
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
intermediate_size
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
intermediate_act_fn
=
config
.
hidden_act
...
...
@@ -460,9 +457,7 @@ class BertPredictionHeadTransform(nn.Module):
def
__init__
(
self
,
config
):
super
(
BertPredictionHeadTransform
,
self
).
__init__
()
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
transform_act_fn
=
config
.
hidden_act
...
...
src/transformers/modeling_tf_albert.py
View file @
798b3b38
...
...
@@ -17,7 +17,6 @@
import
logging
import
sys
import
tensorflow
as
tf
...
...
@@ -311,9 +310,7 @@ class TFAlbertLayer(tf.keras.layers.Layer):
config
.
intermediate_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"ffn"
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
activation
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
activation
=
config
.
hidden_act
...
...
@@ -454,9 +451,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
config
.
embedding_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"dense"
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
activation
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
activation
=
config
.
hidden_act
...
...
src/transformers/modeling_tf_bert.py
View file @
798b3b38
...
...
@@ -17,7 +17,6 @@
import
logging
import
sys
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -310,9 +309,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
config
.
intermediate_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"dense"
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
intermediate_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
intermediate_act_fn
=
config
.
hidden_act
...
...
@@ -417,9 +414,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
self
.
dense
=
tf
.
keras
.
layers
.
Dense
(
config
.
hidden_size
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"dense"
)
if
isinstance
(
config
.
hidden_act
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
hidden_act
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
hidden_act
,
str
):
self
.
transform_act_fn
=
ACT2FN
[
config
.
hidden_act
]
else
:
self
.
transform_act_fn
=
config
.
hidden_act
...
...
src/transformers/modeling_tf_xlnet.py
View file @
798b3b38
...
...
@@ -18,7 +18,6 @@
import
logging
import
sys
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -290,9 +289,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
config
.
d_model
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"layer_2"
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
dropout
)
if
isinstance
(
config
.
ff_activation
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
ff_activation
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
ff_activation
,
str
):
self
.
activation_function
=
ACT2FN
[
config
.
ff_activation
]
else
:
self
.
activation_function
=
config
.
ff_activation
...
...
src/transformers/modeling_xlnet.py
View file @
798b3b38
...
...
@@ -19,7 +19,6 @@
import
logging
import
math
import
sys
import
torch
from
torch
import
nn
...
...
@@ -420,9 +419,7 @@ class XLNetFeedForward(nn.Module):
self
.
layer_1
=
nn
.
Linear
(
config
.
d_model
,
config
.
d_inner
)
self
.
layer_2
=
nn
.
Linear
(
config
.
d_inner
,
config
.
d_model
)
self
.
dropout
=
nn
.
Dropout
(
config
.
dropout
)
if
isinstance
(
config
.
ff_activation
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
config
.
ff_activation
,
unicode
)
# noqa: F821
):
if
isinstance
(
config
.
ff_activation
,
str
):
self
.
activation_function
=
ACT2FN
[
config
.
ff_activation
]
else
:
self
.
activation_function
=
config
.
ff_activation
...
...
src/transformers/tokenization_gpt2.py
View file @
798b3b38
...
...
@@ -18,7 +18,6 @@
import
json
import
logging
import
os
import
sys
from
io
import
open
import
regex
as
re
...
...
@@ -80,7 +79,6 @@ def bytes_to_unicode():
This is a signficant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
"""
_chr
=
unichr
if
sys
.
version_info
[
0
]
==
2
else
chr
# noqa: F821
bs
=
(
list
(
range
(
ord
(
"!"
),
ord
(
"~"
)
+
1
))
+
list
(
range
(
ord
(
"¡"
),
ord
(
"¬"
)
+
1
))
+
list
(
range
(
ord
(
"®"
),
ord
(
"ÿ"
)
+
1
))
)
...
...
@@ -91,7 +89,7 @@ def bytes_to_unicode():
bs
.
append
(
b
)
cs
.
append
(
2
**
8
+
n
)
n
+=
1
cs
=
[
_
chr
(
n
)
for
n
in
cs
]
cs
=
[
chr
(
n
)
for
n
in
cs
]
return
dict
(
zip
(
bs
,
cs
))
...
...
@@ -212,11 +210,6 @@ class GPT2Tokenizer(PreTrainedTokenizer):
bpe_tokens
=
[]
for
token
in
re
.
findall
(
self
.
pat
,
text
):
if
sys
.
version_info
[
0
]
==
2
:
token
=
""
.
join
(
self
.
byte_encoder
[
ord
(
b
)]
for
b
in
token
)
# Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
else
:
token
=
""
.
join
(
self
.
byte_encoder
[
b
]
for
b
in
token
.
encode
(
"utf-8"
)
)
# Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
...
...
src/transformers/tokenization_transfo_xl.py
View file @
798b3b38
...
...
@@ -21,7 +21,7 @@
import
glob
import
logging
import
os
import
sys
import
pickle
from
collections
import
Counter
,
OrderedDict
from
io
import
open
...
...
@@ -36,11 +36,6 @@ try:
except
ImportError
:
pass
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
logger
=
logging
.
getLogger
(
__name__
)
...
...
tests/test_configuration_common.py
View file @
798b3b38
...
...
@@ -16,8 +16,7 @@
import
json
import
os
from
.test_tokenization_common
import
TemporaryDirectory
import
tempfile
class
ConfigTester
(
object
):
...
...
@@ -42,7 +41,7 @@ class ConfigTester(object):
def
create_and_test_config_to_json_file
(
self
):
config_first
=
self
.
config_class
(
**
self
.
inputs_dict
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
json_file_path
=
os
.
path
.
join
(
tmpdirname
,
"config.json"
)
config_first
.
to_json_file
(
json_file_path
)
config_second
=
self
.
config_class
.
from_json_file
(
json_file_path
)
...
...
@@ -52,7 +51,7 @@ class ConfigTester(object):
def
create_and_test_config_from_and_save_pretrained
(
self
):
config_first
=
self
.
config_class
(
**
self
.
inputs_dict
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
config_first
.
save_pretrained
(
tmpdirname
)
config_second
=
self
.
config_class
.
from_pretrained
(
tmpdirname
)
...
...
tests/test_model_card.py
View file @
798b3b38
...
...
@@ -16,12 +16,11 @@
import
json
import
os
import
tempfile
import
unittest
from
transformers.modelcard
import
ModelCard
from
.test_tokenization_common
import
TemporaryDirectory
class
ModelCardTester
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
@@ -65,7 +64,7 @@ class ModelCardTester(unittest.TestCase):
def
test_model_card_to_json_file
(
self
):
model_card_first
=
ModelCard
.
from_dict
(
self
.
inputs_dict
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
filename
=
os
.
path
.
join
(
tmpdirname
,
"modelcard.json"
)
model_card_first
.
to_json_file
(
filename
)
model_card_second
=
ModelCard
.
from_json_file
(
filename
)
...
...
@@ -75,7 +74,7 @@ class ModelCardTester(unittest.TestCase):
def
test_model_card_from_and_save_pretrained
(
self
):
model_card_first
=
ModelCard
.
from_dict
(
self
.
inputs_dict
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model_card_first
.
save_pretrained
(
tmpdirname
)
model_card_second
=
ModelCard
.
from_pretrained
(
tmpdirname
)
...
...
tests/test_modeling_common.py
View file @
798b3b38
...
...
@@ -19,8 +19,6 @@ import json
import
logging
import
os.path
import
random
import
shutil
import
sys
import
tempfile
import
unittest
import
uuid
...
...
@@ -43,23 +41,6 @@ if is_torch_available():
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
)
if
sys
.
version_info
[
0
]
==
2
:
class
TemporaryDirectory
(
object
):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
shutil
.
rmtree
(
self
.
name
)
else
:
TemporaryDirectory
=
tempfile
.
TemporaryDirectory
unicode
=
str
def
_config_zero_init
(
config
):
configs_no_init
=
copy
.
deepcopy
(
config
)
...
...
@@ -92,7 +73,7 @@ class ModelTesterMixin:
out_2
=
outputs
[
0
].
numpy
()
out_2
[
np
.
isnan
(
out_2
)]
=
0
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
model
.
to
(
torch_device
)
...
...
@@ -238,7 +219,7 @@ class ModelTesterMixin:
except
RuntimeError
:
self
.
fail
(
"Couldn't trace module."
)
with
TemporaryDirectory
()
as
tmp_dir_name
:
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir_name
:
pt_file_name
=
os
.
path
.
join
(
tmp_dir_name
,
"traced_model.pt"
)
try
:
...
...
@@ -366,7 +347,7 @@ class ModelTesterMixin:
heads_to_prune
=
{
0
:
list
(
range
(
1
,
self
.
model_tester
.
num_attention_heads
)),
-
1
:
[
0
]}
model
.
prune_heads
(
heads_to_prune
)
with
TemporaryDirectory
()
as
temp_dir_name
:
with
tempfile
.
TemporaryDirectory
()
as
temp_dir_name
:
model
.
save_pretrained
(
temp_dir_name
)
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
model
.
to
(
torch_device
)
...
...
@@ -435,7 +416,7 @@ class ModelTesterMixin:
self
.
assertEqual
(
attentions
[
2
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
self
.
assertEqual
(
attentions
[
3
].
shape
[
-
3
],
self
.
model_tester
.
num_attention_heads
)
with
TemporaryDirectory
()
as
temp_dir_name
:
with
tempfile
.
TemporaryDirectory
()
as
temp_dir_name
:
model
.
save_pretrained
(
temp_dir_name
)
model
=
model_class
.
from_pretrained
(
temp_dir_name
)
model
.
to
(
torch_device
)
...
...
tests/test_modeling_tf_common.py
View file @
798b3b38
...
...
@@ -17,8 +17,6 @@
import
copy
import
os
import
random
import
shutil
import
sys
import
tempfile
from
transformers
import
is_tf_available
,
is_torch_available
...
...
@@ -32,23 +30,6 @@ if is_tf_available():
# from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
if
sys
.
version_info
[
0
]
==
2
:
class
TemporaryDirectory
(
object
):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
shutil
.
rmtree
(
self
.
name
)
else
:
TemporaryDirectory
=
tempfile
.
TemporaryDirectory
unicode
=
str
def
_config_zero_init
(
config
):
configs_no_init
=
copy
.
deepcopy
(
config
)
...
...
@@ -87,7 +68,7 @@ class TFModelTesterMixin:
model
=
model_class
(
config
)
outputs
=
model
(
inputs_dict
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
after_outputs
=
model
(
inputs_dict
)
...
...
@@ -137,7 +118,7 @@ class TFModelTesterMixin:
self
.
assertLessEqual
(
max_diff
,
2e-2
)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
pt_checkpoint_path
=
os
.
path
.
join
(
tmpdirname
,
"pt_model.bin"
)
torch
.
save
(
pt_model
.
state_dict
(),
pt_checkpoint_path
)
tf_model
=
transformers
.
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pt_checkpoint_path
)
...
...
@@ -180,7 +161,7 @@ class TFModelTesterMixin:
model
=
model_class
(
config
)
# Let's load it from the disk to be sure we can use pretrained weights
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
outputs
=
model
(
inputs_dict
)
# build the model
model
.
save_pretrained
(
tmpdirname
)
model
=
model_class
.
from_pretrained
(
tmpdirname
)
...
...
tests/test_optimization.py
View file @
798b3b38
...
...
@@ -15,11 +15,11 @@
import
os
import
tempfile
import
unittest
from
transformers
import
is_torch_available
from
.test_tokenization_common
import
TemporaryDirectory
from
.utils
import
require_torch
...
...
@@ -50,7 +50,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
scheduler
.
step
()
lrs
.
append
(
scheduler
.
get_lr
())
if
step
==
num_steps
//
2
:
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
file_name
=
os
.
path
.
join
(
tmpdirname
,
"schedule.bin"
)
torch
.
save
(
scheduler
.
state_dict
(),
file_name
)
...
...
tests/test_tokenization_common.py
View file @
798b3b38
...
...
@@ -15,33 +15,12 @@
import
os
import
pickle
import
shutil
import
sys
import
tempfile
from
io
import
open
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
class
TemporaryDirectory
(
object
):
"""Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
def
__enter__
(
self
):
self
.
name
=
tempfile
.
mkdtemp
()
return
self
.
name
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
shutil
.
rmtree
(
self
.
name
)
else
:
import
pickle
TemporaryDirectory
=
tempfile
.
TemporaryDirectory
unicode
=
str
class
TokenizerTesterMixin
:
tokenizer_class
=
None
...
...
@@ -90,7 +69,7 @@ class TokenizerTesterMixin:
before_tokens
=
tokenizer
.
encode
(
"He is very happy, UNwant
\u00E9
d,running"
,
add_special_tokens
=
False
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
tokenizer
.
save_pretrained
(
tmpdirname
)
tokenizer
=
self
.
tokenizer_class
.
from_pretrained
(
tmpdirname
)
...
...
@@ -108,7 +87,7 @@ class TokenizerTesterMixin:
text
=
"Munich and Berlin are nice cities"
subwords
=
tokenizer
.
tokenize
(
text
)
with
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
filename
=
os
.
path
.
join
(
tmpdirname
,
"tokenizer.bin"
)
with
open
(
filename
,
"wb"
)
as
handle
:
...
...
@@ -246,7 +225,7 @@ class TokenizerTesterMixin:
self
.
assertEqual
(
text_2
,
output_text
)
self
.
assertNotEqual
(
len
(
tokens_2
),
0
)
self
.
assertIsInstance
(
text_2
,
(
str
,
unicode
)
)
self
.
assertIsInstance
(
text_2
,
str
)
def
test_encode_decode_with_spaces
(
self
):
tokenizer
=
self
.
get_tokenizer
()
...
...
@@ -268,9 +247,6 @@ class TokenizerTesterMixin:
self
.
assertListEqual
(
weights_list
,
weights_list_2
)
def
test_mask_output
(
self
):
if
sys
.
version_info
<=
(
3
,
0
):
return
tokenizer
=
self
.
get_tokenizer
()
if
tokenizer
.
build_inputs_with_special_tokens
.
__qualname__
.
split
(
"."
)[
0
]
!=
"PreTrainedTokenizer"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment