transformers · Commit 2818e505 (unverified)
Authored Dec 24, 2019 by Anthony MOI · parent 31c56f2e

Add tests for fast tokenizers

Showing 3 changed files with 67 additions and 1 deletion:
- tests/test_tokenization_bert.py (+27 −0)
- tests/test_tokenization_common.py (+4 −0)
- tests/test_tokenization_gpt2.py (+36 −1)
tests/test_tokenization_bert.py

```diff
@@ -21,6 +21,7 @@ from transformers.tokenization_bert import (
     VOCAB_FILES_NAMES,
     BasicTokenizer,
     BertTokenizer,
+    BertTokenizerFast,
     WordpieceTokenizer,
     _is_control,
     _is_punctuation,
@@ -34,6 +35,7 @@ from .utils import slow
 class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = BertTokenizer
+    test_rust_tokenizer = True
 
     def setUp(self):
         super(BertTokenizationTest, self).setUp()
@@ -60,6 +62,9 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def get_tokenizer(self, **kwargs):
         return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
+    def get_rust_tokenizer(self, **kwargs):
+        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
     def get_input_output_texts(self):
         input_text = "UNwant\u00E9d,running"
         output_text = "unwanted, running"
@@ -72,6 +77,28 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
         self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
+    def test_rust_and_python_full_tokenizers(self):
+        if not self.test_rust_tokenizer:
+            return
+
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False)
+
+        sequence = u"UNwant\u00E9d,running"
+
+        tokens = tokenizer.tokenize(sequence)
+        rust_tokens = rust_tokenizer.tokenize(sequence)
+        self.assertListEqual(tokens, rust_tokens)
+
+        ids = tokenizer.encode(sequence, add_special_tokens=False)
+        rust_ids = rust_tokenizer.encode(sequence)
+        self.assertListEqual(ids, rust_ids)
+
+        rust_tokenizer = self.get_rust_tokenizer()
+        ids = tokenizer.encode(sequence)
+        rust_ids = rust_tokenizer.encode(sequence)
+        self.assertListEqual(ids, rust_ids)
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()
```
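The new test_rust_and_python_full_tokenizers checks that the Rust-backed BertTokenizerFast and the pure-Python BertTokenizer agree both token-for-token and id-for-id. A minimal standalone sketch of the same parity check: "bert-base-uncased" is an illustrative public checkpoint (the test itself builds a tiny fixture vocabulary in tmpdirname), and passing add_special_tokens at construction follows the fast-tokenizer API as it stood at this commit.

```python
from transformers.tokenization_bert import BertTokenizer, BertTokenizerFast

# Load the same vocabulary into both backends; "bert-base-uncased" is an
# illustrative checkpoint, not the fixture vocab the test suite uses.
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
fast_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", add_special_tokens=False)

sequence = "UNwant\u00E9d,running"

# Token-level parity: both backends should produce identical WordPiece tokens.
assert slow_tokenizer.tokenize(sequence) == fast_tokenizer.tokenize(sequence)

# Id-level parity, with special-token handling aligned on both sides.
assert slow_tokenizer.encode(sequence, add_special_tokens=False) == fast_tokenizer.encode(sequence)
```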
tests/test_tokenization_common.py

```diff
@@ -23,6 +23,7 @@ import tempfile
 class TokenizerTesterMixin:
 
     tokenizer_class = None
+    test_rust_tokenizer = False
 
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
@@ -33,6 +34,9 @@ class TokenizerTesterMixin:
     def get_tokenizer(self, **kwargs):
         raise NotImplementedError
 
+    def get_rust_tokenizer(self, **kwargs):
+        raise NotImplementedError
+
     def get_input_output_texts(self):
         raise NotImplementedError
```
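The mixin change is what keeps every other suite passing untouched: test_rust_tokenizer defaults to False, so the shared parity test returns immediately, and get_rust_tokenizer only needs an override where a fast tokenizer actually exists. A minimal sketch of the opt-in pattern a suite follows, mirroring the BERT suite above (MyBertTokenizationTest is a hypothetical name for illustration):

```python
import unittest

from transformers.tokenization_bert import BertTokenizer, BertTokenizerFast

from .test_tokenization_common import TokenizerTesterMixin


class MyBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):  # hypothetical suite
    tokenizer_class = BertTokenizer
    test_rust_tokenizer = True  # opt in; the mixin default is False

    def get_tokenizer(self, **kwargs):
        # self.tmpdirname holds the fixture vocabulary written in setUp()
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
```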
tests/test_tokenization_gpt2.py

```diff
@@ -18,7 +18,7 @@ import json
 import os
 import unittest
 
-from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer
+from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer, GPT2TokenizerFast
 
 from .test_tokenization_common import TokenizerTesterMixin
@@ -26,6 +26,7 @@ from .test_tokenization_common import TokenizerTesterMixin
 class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
 
     tokenizer_class = GPT2Tokenizer
+    test_rust_tokenizer = True
 
     def setUp(self):
         super(GPT2TokenizationTest, self).setUp()
@@ -68,6 +69,10 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         kwargs.update(self.special_tokens_map)
         return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
 
+    def get_rust_tokenizer(self, **kwargs):
+        kwargs.update(self.special_tokens_map)
+        return GPT2TokenizerFast.from_pretrained(self.tmpdirname, **kwargs)
+
     def get_input_output_texts(self):
         input_text = "lower newer"
         output_text = "lower newer"
@@ -83,3 +88,33 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         input_tokens = tokens + [tokenizer.unk_token]
         input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
+
+    def test_rust_and_python_full_tokenizers(self):
+        if not self.test_rust_tokenizer:
+            return
+
+        tokenizer = self.get_tokenizer()
+        rust_tokenizer = self.get_rust_tokenizer(add_special_tokens=False, add_prefix_space=True)
+
+        sequence = u"lower newer"
+
+        # Testing tokenization
+        tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
+        rust_tokens = rust_tokenizer.tokenize(sequence)
+        self.assertListEqual(tokens, rust_tokens)
+
+        # Testing conversion to ids without special tokens
+        ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
+        rust_ids = rust_tokenizer.encode(sequence)
+        self.assertListEqual(ids, rust_ids)
+
+        # Testing conversion to ids with special tokens
+        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
+        ids = tokenizer.encode(sequence, add_prefix_space=True)
+        rust_ids = rust_tokenizer.encode(sequence)
+        self.assertListEqual(ids, rust_ids)
+
+        # Testing the unknown token
+        input_tokens = tokens + [rust_tokenizer.unk_token]
+        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
+        self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
```
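The GPT-2 variant of the parity test has one wrinkle: byte-level BPE treats a leading space as part of the first token, and at this commit the Python tokenizer takes add_prefix_space per call while the fast tokenizer takes it at construction. A standalone sketch of the same comparison, assuming the public "gpt2" checkpoint (the test uses its small fixture vocab instead):

```python
from transformers.tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast

# "gpt2" is an illustrative checkpoint; the test builds a tiny fixture vocab.
slow_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
fast_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)

sequence = "lower newer"

# The slow tokenizer takes add_prefix_space per call, the fast one at
# construction; with both set, tokens and ids should line up exactly.
assert slow_tokenizer.tokenize(sequence, add_prefix_space=True) == fast_tokenizer.tokenize(sequence)
assert slow_tokenizer.encode(sequence, add_prefix_space=True) == fast_tokenizer.encode(sequence)
```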