chenpangpang / transformers / Commits

Commit abe734ca, authored Aug 30, 2019 by thomwolf

    fix GPT-2 and RoBERTa tests to be clean now

parent 0f5a7994
Showing 3 changed files with 20 additions and 18 deletions (+20 -18):

    pytorch_transformers/tests/tokenization_gpt2_test.py      +9 -8
    pytorch_transformers/tests/tokenization_roberta_test.py   +9 -8
    pytorch_transformers/tests/tokenization_tests_commons.py  +2 -2
pytorch_transformers/tests/tokenization_gpt2_test.py
...
...
@@ -31,17 +31,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
-                 "lo", "low", "er",
-                 "low", "lowest", "newer", "wider", "<unk>"]
+                 "\u0120", "\u0120l", "\u0120n",
+                 "\u0120lo", "\u0120low", "er",
+                 "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
-        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
+        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
         self.special_tokens_map = {"unk_token": "<unk>"}

         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-        with open(self.vocab_file, "w") as fp:
+        with open(self.vocab_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(vocab_tokens))
-        with open(self.merges_file, "w") as fp:
+        with open(self.merges_file, "w", encoding="utf-8") as fp:
             fp.write("\n".join(merges))

     def get_tokenizer(self):
...
...
@@ -49,18 +50,18 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u"lower<unk>newer"
+        output_text = u" lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = GPT2Tokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
-        bpe_tokens = ["low", "er"]
+        bpe_tokens = ["\u0120low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [13, 12, 17]
+        input_bpe_tokens = [14, 15, 19]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
...
...
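A quick sanity check on the new expected ids (a standalone sketch, not part of the commit): the test builds its id map with dict(zip(vocab, range(len(vocab)))), so the positions of the \u0120-prefixed entries in the new fixture determine the values in input_bpe_tokens.

    # Sketch only: vocab is copied verbatim from the + side of the diff above.
    vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
             "\u0120", "\u0120l", "\u0120n",
             "\u0120lo", "\u0120low", "er",
             "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
    vocab_tokens = dict(zip(vocab, range(len(vocab))))

    # The test tokenizes "lower" to ["\u0120low", "er"] and appends the unk
    # token, so the ids it should expect are exactly the new [14, 15, 19].
    assert [vocab_tokens[t] for t in ["\u0120low", "er", "<unk>"]] == [14, 15, 19]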
pytorch_transformers/tests/tokenization_roberta_test.py
...
...
@@ -30,17 +30,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
-                 "lo", "low", "er",
-                 "low", "lowest", "newer", "wider", "<unk>"]
+                 "\u0120", "\u0120l", "\u0120n",
+                 "\u0120lo", "\u0120low", "er",
+                 "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
-        merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
+        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
         self.special_tokens_map = {"unk_token": "<unk>"}

         self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
         self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['merges_file'])
-        with open(self.vocab_file, "w") as fp:
+        with open(self.vocab_file, "w", encoding="utf-8") as fp:
             fp.write(json.dumps(vocab_tokens))
-        with open(self.merges_file, "w") as fp:
+        with open(self.merges_file, "w", encoding="utf-8") as fp:
             fp.write("\n".join(merges))

     def get_tokenizer(self):
...
...
@@ -48,18 +49,18 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
     def get_input_output_texts(self):
         input_text = u"lower newer"
-        output_text = u"lower<unk>newer"
+        output_text = u" lower newer"
         return input_text, output_text

     def test_full_tokenizer(self):
         tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
-        bpe_tokens = ["low", "er"]
+        bpe_tokens = ["\u0120low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)

         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [13, 12, 17]
+        input_bpe_tokens = [14, 15, 19]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
...
...
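RobertaTokenizer reuses GPT-2's byte-level BPE, which is why both fixtures gain the same \u0120-prefixed entries. \u0120 ('Ġ') is the character the space byte maps to in GPT-2's byte-to-unicode table; below is a minimal re-derivation in the spirit of bytes_to_unicode() in pytorch_transformers/tokenization_gpt2.py (a sketch for illustration, not code from this commit):

    def bytes_to_unicode_sketch():
        # Printable byte ranges keep their own codepoints; every other byte
        # is shifted past 255 so the vocab never contains raw whitespace or
        # control characters.
        bs = (list(range(ord("!"), ord("~") + 1))
              + list(range(ord("\xa1"), ord("\xac") + 1))
              + list(range(ord("\xae"), ord("\xff") + 1)))
        cs = bs[:]
        n = 0
        for b in range(256):
            if b not in bs:
                bs.append(b)
                cs.append(256 + n)  # next free codepoint above the byte range
                n += 1
        return dict(zip(bs, map(chr, cs)))

    # Bytes 0x00-0x20 are all non-printable, so the space byte 0x20 is the
    # 33rd of them and lands at 256 + 32 = 0x120, i.e. "\u0120" ('Ġ'): a
    # word-initial space survives as the \u0120 prefix on BPE tokens.
    assert bytes_to_unicode_sketch()[ord(" ")] == "\u0120"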
pytorch_transformers/tests/tokenization_tests_commons.py
...
...
@@ -129,7 +129,7 @@ class CommonTestCases:
         self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
         self.assertGreater(tokens[-2], tokens[-3])
         self.assertEqual(tokens[0], tokenizer.eos_token_id)
-        self.assertEqual(tokens[-2], tokenizer.eos_token_id)
+        self.assertEqual(tokens[-2], tokenizer.pad_token_id)

     def test_required_methods_tokenizer(self):
...
...
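The one-line fix in the common test: the second-to-last id of the encoded sequence should belong to the newly added pad token, not the eos token. The hunk does not show how tokens is built, so the fixture below is a hedged reconstruction; add_special_tokens() and encode() are real pytorch_transformers APIs, but the token strings, the sample text, and the file paths are illustrative assumptions:

    # Hedged reconstruction: only the two assertions mirror the diff above.
    # vocab_file / merges_file: toy fixture files like the ones the tests write.
    tokenizer = GPT2Tokenizer(vocab_file, merges_file, unk_token="<unk>")
    tokenizer.add_special_tokens({"eos_token": "<my-eos>",   # new special tokens get
                                  "pad_token": "<my-pad>"})  # ids past vocab_size - 1
    tokens = tokenizer.encode("<my-eos> lower newer <my-pad> l")
    assert tokens[0] == tokenizer.eos_token_id   # unchanged assertion in the hunk
    assert tokens[-2] == tokenizer.pad_token_id  # the corrected assertion: " l" is a
                                                 # single id, so the pad id sits at -2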