Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2818e505
Unverified
Commit
2818e505
authored
Dec 24, 2019
by
Anthony MOI
Browse files
Add tests for fast tokenizers
parent
31c56f2e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
1 deletion
+67
-1
tests/test_tokenization_bert.py
tests/test_tokenization_bert.py
+27
-0
tests/test_tokenization_common.py
tests/test_tokenization_common.py
+4
-0
tests/test_tokenization_gpt2.py
tests/test_tokenization_gpt2.py
+36
-1
No files found.
tests/test_tokenization_bert.py
View file @
2818e505
...
...
@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
VOCAB_FILES_NAMES
,
BasicTokenizer
,
BertTokenizer
,
BertTokenizerFast
,
WordpieceTokenizer
,
_is_control
,
_is_punctuation
,
...
...
@@ -34,6 +35,7 @@ from .utils import slow
class
BertTokenizationTest
(
TokenizerTesterMixin
,
unittest
.
TestCase
):
tokenizer_class
=
BertTokenizer
test_rust_tokenizer
=
True
def
setUp
(
self
):
super
(
BertTokenizationTest
,
self
).
setUp
()
...
...
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def
get_tokenizer
(
self
,
**
kwargs
):
return
BertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_rust_tokenizer
(
self
,
**
kwargs
):
return
BertTokenizerFast
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
"UNwant
\u00E9
d,running"
output_text
=
"unwanted, running"
...
...
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self
.
assertListEqual
(
tokens
,
[
"un"
,
"##want"
,
"##ed"
,
","
,
"runn"
,
"##ing"
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
7
,
4
,
5
,
10
,
8
,
9
])
def
test_rust_and_python_full_tokenizers
(
self
):
if
not
self
.
test_rust_tokenizer
:
return
tokenizer
=
self
.
get_tokenizer
()
rust_tokenizer
=
self
.
get_rust_tokenizer
(
add_special_tokens
=
False
)
sequence
=
u
"UNwant
\u00E9
d,running"
tokens
=
tokenizer
.
tokenize
(
sequence
)
rust_tokens
=
rust_tokenizer
.
tokenize
(
sequence
)
self
.
assertListEqual
(
tokens
,
rust_tokens
)
ids
=
tokenizer
.
encode
(
sequence
,
add_special_tokens
=
False
)
rust_ids
=
rust_tokenizer
.
encode
(
sequence
)
self
.
assertListEqual
(
ids
,
rust_ids
)
rust_tokenizer
=
self
.
get_rust_tokenizer
()
ids
=
tokenizer
.
encode
(
sequence
)
rust_ids
=
rust_tokenizer
.
encode
(
sequence
)
self
.
assertListEqual
(
ids
,
rust_ids
)
def
test_chinese
(
self
):
tokenizer
=
BasicTokenizer
()
...
...
tests/test_tokenization_common.py
View file @
2818e505
...
...
@@ -23,6 +23,7 @@ import tempfile
class
TokenizerTesterMixin
:
tokenizer_class
=
None
test_rust_tokenizer
=
False
def
setUp
(
self
):
self
.
tmpdirname
=
tempfile
.
mkdtemp
()
...
...
@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
def
get_tokenizer
(
self
,
**
kwargs
):
raise
NotImplementedError
def
get_rust_tokenizer
(
self
,
**
kwargs
):
raise
NotImplementedError
def
get_input_output_texts
(
self
):
raise
NotImplementedError
...
...
tests/test_tokenization_gpt2.py
View file @
2818e505
...
...
@@ -18,7 +18,7 @@ import json
import
os
import
unittest
from
transformers.tokenization_gpt2
import
VOCAB_FILES_NAMES
,
GPT2Tokenizer
from
transformers.tokenization_gpt2
import
VOCAB_FILES_NAMES
,
GPT2Tokenizer
,
GPT2TokenizerFast
from
.test_tokenization_common
import
TokenizerTesterMixin
...
...
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
class
GPT2TokenizationTest
(
TokenizerTesterMixin
,
unittest
.
TestCase
):
tokenizer_class
=
GPT2Tokenizer
test_rust_tokenizer
=
True
def
setUp
(
self
):
super
(
GPT2TokenizationTest
,
self
).
setUp
()
...
...
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
kwargs
.
update
(
self
.
special_tokens_map
)
return
GPT2Tokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_rust_tokenizer
(
self
,
**
kwargs
):
kwargs
.
update
(
self
.
special_tokens_map
)
return
GPT2TokenizerFast
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
"lower newer"
output_text
=
"lower newer"
...
...
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_tokens
=
tokens
+
[
tokenizer
.
unk_token
]
input_bpe_tokens
=
[
14
,
15
,
10
,
9
,
3
,
2
,
15
,
19
]
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
def
test_rust_and_python_full_tokenizers
(
self
):
if
not
self
.
test_rust_tokenizer
:
return
tokenizer
=
self
.
get_tokenizer
()
rust_tokenizer
=
self
.
get_rust_tokenizer
(
add_special_tokens
=
False
,
add_prefix_space
=
True
)
sequence
=
u
"lower newer"
# Testing tokenization
tokens
=
tokenizer
.
tokenize
(
sequence
,
add_prefix_space
=
True
)
rust_tokens
=
rust_tokenizer
.
tokenize
(
sequence
)
self
.
assertListEqual
(
tokens
,
rust_tokens
)
# Testing conversion to ids without special tokens
ids
=
tokenizer
.
encode
(
sequence
,
add_special_tokens
=
False
,
add_prefix_space
=
True
)
rust_ids
=
rust_tokenizer
.
encode
(
sequence
)
self
.
assertListEqual
(
ids
,
rust_ids
)
# Testing conversion to ids with special tokens
rust_tokenizer
=
self
.
get_rust_tokenizer
(
add_prefix_space
=
True
)
ids
=
tokenizer
.
encode
(
sequence
,
add_prefix_space
=
True
)
rust_ids
=
rust_tokenizer
.
encode
(
sequence
)
self
.
assertListEqual
(
ids
,
rust_ids
)
# Testing the unknown token
input_tokens
=
tokens
+
[
rust_tokenizer
.
unk_token
]
input_bpe_tokens
=
[
14
,
15
,
10
,
9
,
3
,
2
,
15
,
19
]
self
.
assertListEqual
(
rust_tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment