chenpangpang / transformers — Commit c079d7dd

Commit c079d7dd authored Jul 09, 2019 by thomwolf

    fix python 2 tests

Parent: b1978698

Showing 8 changed files with 39 additions and 37 deletions (+39 -37)
pytorch_transformers/tests/tokenization_bert_test.py        +10 -13
pytorch_transformers/tests/tokenization_gpt2_test.py         +2  -3
pytorch_transformers/tests/tokenization_openai_test.py       +2  -3
pytorch_transformers/tests/tokenization_tests_commons.py    +12  -5
pytorch_transformers/tests/tokenization_transfo_xl_test.py   +2  -3
pytorch_transformers/tests/tokenization_xlm_test.py          +2  -3
pytorch_transformers/tests/tokenization_xlnet_test.py        +3  -4
pytorch_transformers/tokenization_utils.py                   +6  -3
pytorch_transformers/tests/tokenization_bert_test.py  (view file @ c079d7dd)

@@ -24,7 +24,7 @@ from pytorch_transformers.tokenization_bert import (BasicTokenizer,
                                         _is_control, _is_punctuation,
                                         _is_whitespace, VOCAB_FILES_NAMES)
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class TokenizationTest(unittest.TestCase):
...
@@ -33,13 +33,12 @@ class TokenizationTest(unittest.TestCase):
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
             "##ing", ",", "low", "lowest",
         ]
-        vocab_directory = "/tmp/"
-        vocab_file = os.path.join(vocab_directory, VOCAB_FILES_NAMES['vocab_file'])
-        with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
-
-        create_and_check_tokenizer_commons(self, BertTokenizer, pretrained_model_name_or_path=vocab_directory)
-
-        tokenizer = BertTokenizer(vocab_file)
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+
+            create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
+
+            tokenizer = BertTokenizer(vocab_file)
...
@@ -47,8 +46,6 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
         self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
-        os.remove(vocab_file)
-
     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
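For context, the new test body boils down to writing a toy WordPiece vocab into a temporary directory and loading a BertTokenizer from it; the cleanup that used to rely on os.remove now happens when the whole directory is deleted. A rough standalone sketch of the same pattern, using mkdtemp()/rmtree directly (it assumes pytorch_transformers is installed; the tokenize call and expected output are reconstructed from the assertion in the last hunk, not shown verbatim in the diff):

import os
import shutil
import tempfile
from io import open

from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa",
                "un", "runn", "##ing", ",", "low", "lowest"]

tmpdirname = tempfile.mkdtemp()
try:
    # Write the toy vocab where the tokenizer expects it (usually vocab.txt).
    vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write(u"".join([x + u"\n" for x in vocab_tokens]))

    tokenizer = BertTokenizer(vocab_file)
    tokens = tokenizer.tokenize(u"UNwanted,running")
    print(tokens)  # expected per the assertion above: ["un", "##want", "##ed", ",", "runn", "##ing"]
finally:
    shutil.rmtree(tmpdirname)  # the whole directory goes away; no per-file os.remove needed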
pytorch_transformers/tests/tokenization_gpt2_test.py  (view file @ c079d7dd)

@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class GPT2TokenizationTest(unittest.TestCase):
...
@@ -34,7 +33,7 @@ class GPT2TokenizationTest(unittest.TestCase):
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
         special_tokens_map = {"unk_token": "<unk>"}
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
pytorch_transformers/tests/tokenization_openai_test.py  (view file @ c079d7dd)

@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class OpenAIGPTTokenizationTest(unittest.TestCase):
...
@@ -35,7 +34,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
pytorch_transformers/tests/tokenization_tests_commons.py  (view file @ c079d7dd)

@@ -14,18 +14,25 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals
 
+import os
 import sys
 from io import open
 import tempfile
+import shutil
 
-if sys.version_info[0] == 3:
-    unicode = str
-
 if sys.version_info[0] == 2:
     import cPickle as pickle
+
+    class TemporaryDirectory(object):
+        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
+        def __enter__(self):
+            self.name = tempfile.mkdtemp()
+            return self.name
+        def __exit__(self, exc_type, exc_value, traceback):
+            shutil.rmtree(self.name)
 else:
     import pickle
+    TemporaryDirectory = tempfile.TemporaryDirectory
+    unicode = str
 
 def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
...
@@ -33,7 +40,7 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
         before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
             tokenizer = tokenizer.from_pretrained(tmpdirname)
...
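The hunk above is the core of the fix: Python 2's tempfile module has no TemporaryDirectory, so the tests get a small context manager built on mkdtemp()/rmtree, while Python 3 keeps the standard library class. A minimal standalone sketch of how the shim behaves on either interpreter (the prints are illustrative, not part of the test suite):

import os
import shutil
import sys
import tempfile

if sys.version_info[0] == 2:
    # Same idea as the shim added above: wrap mkdtemp() so it works with "with".
    class TemporaryDirectory(object):
        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name
        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
else:
    TemporaryDirectory = tempfile.TemporaryDirectory

with TemporaryDirectory() as tmpdirname:
    print(os.path.isdir(tmpdirname))   # True: the directory exists inside the block
print(os.path.isdir(tmpdirname))       # False: it is removed on exit, on Python 2 and 3 alike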
pytorch_transformers/tests/tokenization_transfo_xl_test.py  (view file @ c079d7dd)

@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import tempfile
 
 from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class TransfoXLTokenizationTest(unittest.TestCase):
...
@@ -30,7 +29,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un",
             "running", ",", "low", "l",
         ]
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
...
pytorch_transformers/tests/tokenization_xlm_test.py  (view file @ c079d7dd)

@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import tempfile
 
 from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 class XLMTokenizationTest(unittest.TestCase):
...
@@ -34,7 +33,7 @@ class XLMTokenizationTest(unittest.TestCase):
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
             merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
             with open(vocab_file, "w") as fp:
...
pytorch_transformers/tests/tokenization_xlnet_test.py  (view file @ c079d7dd)

@@ -16,11 +16,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
-import tempfile
 
-from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE, VOCAB_FILES_NAMES)
+from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)
 
-from .tokenization_tests_commons import create_and_check_tokenizer_commons
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory
 
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                             'fixtures/test_sentencepiece.model')
...
@@ -30,7 +29,7 @@ class XLNetTokenizationTest(unittest.TestCase):
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
 
             create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
...
pytorch_transformers/tokenization_utils.py  (view file @ c079d7dd)

@@ -231,8 +231,7 @@ class PreTrainedTokenizer(object):
         # Add supplementary tokens.
         if added_tokens_file is not None:
-            added_tokens = json.load(open(added_tokens_file, encoding="utf-8"))
-            added_tok_encoder = dict((tok, len(tokenizer) + i) for i, tok in enumerate(added_tokens))
+            added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
             added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
             tokenizer.added_tokens_encoder.update(added_tok_encoder)
             tokenizer.added_tokens_decoder.update(added_tok_decoder)
...
@@ -256,7 +255,11 @@ class PreTrainedTokenizer(object):
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
 
         with open(added_tokens_file, 'w', encoding='utf-8') as f:
-            f.write(json.dumps(self.added_tokens_decoder, ensure_ascii=False))
+            if self.added_tokens_encoder:
+                out_str = json.dumps(self.added_tokens_decoder, ensure_ascii=False)
+            else:
+                out_str = u"{}"
+            f.write(out_str)
 
         vocab_files = self.save_vocabulary(save_directory)
...
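The loading change in the first hunk treats the added-tokens JSON as a ready-made token-to-id mapping (added_tok_encoder) and derives the id-to-token mapping by inverting it, instead of re-indexing the tokens with len(tokenizer) + i; the second hunk makes the save side write u"{}" when there are no added tokens, so the loader always sees valid JSON. A minimal sketch of that load-side pattern, with hypothetical tokens and ids (not values from the library):

import json

# Hypothetical added-tokens payload: token -> id, as the new loading code expects.
payload = json.dumps({"<special1>": 30522, "<special2>": 30523}, ensure_ascii=False)

added_tok_encoder = json.loads(payload)                           # token -> id
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}  # id -> token, by inversion

print(added_tok_encoder)   # {'<special1>': 30522, '<special2>': 30523}
print(added_tok_decoder)   # {30522: '<special1>', 30523: '<special2>'}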