chenpangpang / transformers · Commits

Commit 5daca95d (unverified)
Authored Dec 22, 2019 by Thomas Wolf; committed by GitHub on Dec 22, 2019

Merge pull request #2268 from aaugustin/improve-repository-structure

Improve repository structure

Parents: 54abc67a, 00204f2b
Changes: 167 changed files in the full commit; this page shows 20 changed files with 559 additions and 111 deletions (+559 −111).
tests/test_modeling_tf_xlm.py             +3    −7
tests/test_modeling_tf_xlnet.py           +3    −7
tests/test_modeling_transfo_xl.py         +3    −7
tests/test_modeling_xlm.py                +3    −7
tests/test_modeling_xlnet.py              +3    −7
tests/test_optimization.py                +1    −5
tests/test_optimization_tf.py             +0    −4
tests/test_pipelines.py                   +2    −5
tests/test_tokenization_albert.py         +2    −6
tests/test_tokenization_auto.py           +0    −4
tests/test_tokenization_bert.py           +2    −6
tests/test_tokenization_bert_japanese.py  +4    −3
tests/test_tokenization_common.py         +520  −0
tests/test_tokenization_ctrl.py           +2    −6
tests/test_tokenization_distilbert.py     +1    −7
tests/test_tokenization_gpt2.py           +2    −6
tests/test_tokenization_openai.py         +2    −6
tests/test_tokenization_roberta.py        +2    −6
tests/test_tokenization_t5.py             +2    −6
tests/test_tokenization_transfo_xl.py     +2    −6

transformers/tests/modeling_tf_xlm_test.py → tests/test_modeling_tf_xlm.py

@@ -18,8 +18,8 @@ import unittest
 from transformers import is_tf_available

-from .configuration_common_test import ConfigTester
-from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_tf, slow

@@ -36,7 +36,7 @@ if is_tf_available():
 @require_tf
-class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
+class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):

     all_model_classes = (
         (TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple)

@@ -306,7 +306,3 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
         for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFXLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-
-
-if __name__ == "__main__":
-    unittest.main()
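
The same mechanical change repeats in the model-test files that follow: each file is renamed from `*_test.py` to `test_*.py`, the shared base class `TFCommonTestCases.TFCommonModelTester` (or `CommonTestCases.CommonModelTester` on the PyTorch side) is replaced by a mixin combined with `unittest.TestCase`, and the trailing `if __name__ == "__main__": unittest.main()` block is dropped, which fits the standard pytest/unittest discovery convention implied by the new `test_*.py` names. A minimal sketch of the resulting module shape, assuming the mixin body lives in `tests/test_modeling_tf_common.py` as the imports above suggest (abbreviated; only the names visible in the hunks are taken from the diff):

    # Sketch of the post-rename layout; not a complete test module.
    import unittest

    from transformers import is_tf_available

    from .test_modeling_tf_common import TFModelTesterMixin
    from .utils import require_tf

    if is_tf_available():
        from transformers import TFXLMModel  # plus the other TFXLM* heads shown above


    @require_tf
    class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
        # Shared checks (save/load round-trips, output shapes, ...) come from the
        # mixin; the module itself only declares the model-specific pieces.
        all_model_classes = (TFXLMModel,) if is_tf_available() else ()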

transformers/tests/modeling_tf_xlnet_test.py → tests/test_modeling_tf_xlnet.py

@@ -19,8 +19,8 @@ import unittest
 from transformers import XLNetConfig, is_tf_available

-from .configuration_common_test import ConfigTester
-from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_tf, slow

@@ -38,7 +38,7 @@ if is_tf_available():
 @require_tf
-class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
+class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):

     all_model_classes = (
         (

@@ -401,7 +401,3 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
         for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/modeling_transfo_xl_test.py → tests/test_modeling_transfo_xl.py

@@ -19,8 +19,8 @@ import unittest
 from transformers import is_torch_available

-from .configuration_common_test import ConfigTester
-from .modeling_common_test import CommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_torch, slow, torch_device

@@ -31,7 +31,7 @@ if is_torch_available():
 @require_torch
-class TransfoXLModelTest(CommonTestCases.CommonModelTester):
+class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
     test_pruning = False

@@ -208,7 +208,3 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
         for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/modeling_xlm_test.py → tests/test_modeling_xlm.py

@@ -18,8 +18,8 @@ import unittest
 from transformers import is_torch_available

-from .configuration_common_test import ConfigTester
-from .modeling_common_test import CommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_torch, slow, torch_device

@@ -36,7 +36,7 @@ if is_torch_available():
 @require_torch
-class XLMModelTest(CommonTestCases.CommonModelTester):
+class XLMModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
         (

@@ -390,7 +390,3 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
         for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/modeling_xlnet_test.py → tests/test_modeling_xlnet.py

@@ -19,8 +19,8 @@ import unittest
 from transformers import is_torch_available

-from .configuration_common_test import ConfigTester
-from .modeling_common_test import CommonTestCases, ids_tensor
+from .test_configuration_common import ConfigTester
+from .test_modeling_common import ModelTesterMixin, ids_tensor
 from .utils import CACHE_DIR, require_torch, slow, torch_device

@@ -39,7 +39,7 @@ if is_torch_available():
 @require_torch
-class XLNetModelTest(CommonTestCases.CommonModelTester):
+class XLNetModelTest(ModelTesterMixin, unittest.TestCase):

     all_model_classes = (
         (

@@ -499,7 +499,3 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
         for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
             model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
             self.assertIsNotNone(model)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/optimization_test.py → tests/test_optimization.py

@@ -19,7 +19,7 @@ import unittest
 from transformers import is_torch_available

-from .tokenization_tests_commons import TemporaryDirectory
+from .test_tokenization_common import TemporaryDirectory
 from .utils import require_torch

@@ -150,7 +150,3 @@ class ScheduleInitTest(unittest.TestCase):
             )
             lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
             self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/optimization_tf_test.py → tests/test_optimization_tf.py

@@ -83,7 +83,3 @@ class OptimizationFTest(unittest.TestCase):
         self.assertEqual(accumulator.step, 0)
         self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
         self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/pipelines_test.py → tests/test_pipelines.py

@@ -2,7 +2,8 @@ import unittest
 from typing import Iterable

 from transformers import pipeline
-from transformers.tests.utils import require_tf, require_torch
+
+from .utils import require_tf, require_torch

 QA_FINETUNED_MODELS = {

@@ -204,7 +205,3 @@ class MultiColumnInputTestCase(unittest.TestCase):
         for tokenizer, model, config in TF_QA_FINETUNED_MODELS:
             nlp = pipeline(task="question-answering", model=model, config=config, tokenizer=tokenizer)
             self._test_multicolumn_pipeline(nlp, valid_samples, invalid_samples, mandatory_output_keys)
-
-
-if __name__ == "__main__":
-    unittest.main()
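
One detail specific to this file: the old version imported its helpers through the installed package path (`transformers.tests.utils`), which only resolves while the tests ship inside the `transformers` package; after the move to a top-level `tests/` directory the import becomes package-relative, as in the other files. Relative imports mean the modules are meant to be collected as the `tests` package rather than run as scripts, consistent with the removal of the `unittest.main()` blocks. A small sketch of programmatic discovery under that assumption (the invocation below is illustrative and not part of this commit):

    # Assumes the repository root is the working directory and tests/ is an
    # importable package; matches the new "test_*.py" naming convention.
    import unittest

    if __name__ == "__main__":
        suite = unittest.defaultTestLoader.discover(start_dir="tests", pattern="test_*.py")
        unittest.TextTestRunner(verbosity=2).run(suite)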

transformers/tests/tokenization_albert_test.py → tests/test_tokenization_albert.py

@@ -19,13 +19,13 @@ import unittest
 from transformers.tokenization_albert import AlbertTokenizer

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin

 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/spiece.model")


-class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = AlbertTokenizer

@@ -78,7 +78,3 @@ class AlbertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [
             tokenizer.sep_token_id
         ]
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_auto_test.py → tests/test_tokenization_auto.py

@@ -49,7 +49,3 @@ class AutoTokenizerTest(unittest.TestCase):
         tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(tokenizer, BertTokenizer)
         self.assertEqual(len(tokenizer), 12)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_bert_test.py → tests/test_tokenization_bert.py

@@ -28,11 +28,11 @@ from transformers.tokenization_bert import (
     _is_whitespace,
 )

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
 from .utils import slow


-class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = BertTokenizer

@@ -146,7 +146,3 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         assert encoded_sentence == [101] + text + [102]
         assert encoded_pair == [101] + text + [102] + text_2 + [102]
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_bert_japanese_test.py → tests/test_tokenization_bert_japanese.py

@@ -15,6 +15,7 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

 import os
+import unittest
 from io import open

 from transformers.tokenization_bert import WordpieceTokenizer

@@ -25,12 +26,12 @@ from transformers.tokenization_bert_japanese import (
     MecabTokenizer,
 )

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
 from .utils import custom_tokenizers, slow


 @custom_tokenizers
-class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = BertJapaneseTokenizer

@@ -130,7 +131,7 @@ class BertJapaneseTokenizationTest(CommonTestCases.CommonTokenizerTester):
         assert encoded_pair == [2] + text + [3] + text_2 + [3]


-class BertJapaneseCharacterTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = BertJapaneseTokenizer

tests/test_tokenization_common.py (new file, 0 → 100644)

# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import shutil
import sys
import tempfile
from io import open


if sys.version_info[0] == 2:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""

        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name

        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)


else:
    import pickle

    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str


class TokenizerTesterMixin:

    tokenizer_class = None

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        raise NotImplementedError

    def get_input_output_texts(self):
        raise NotImplementedError

    def test_tokenizers_common_properties(self):
        tokenizer = self.get_tokenizer()
        attributes_list = [
            "bos_token",
            "eos_token",
            "unk_token",
            "sep_token",
            "pad_token",
            "cls_token",
            "mask_token",
        ]
        for attr in attributes_list:
            self.assertTrue(hasattr(tokenizer, attr))
            self.assertTrue(hasattr(tokenizer, attr + "_id"))

        self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
        self.assertTrue(hasattr(tokenizer, "additional_special_tokens_ids"))

        attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", "added_tokens_decoder"]
        for attr in attributes_list:
            self.assertTrue(hasattr(tokenizer, attr))

    def test_save_and_load_tokenizer(self):
        # safety check on max_len default value so we are sure the test works
        tokenizer = self.get_tokenizer()
        self.assertNotEqual(tokenizer.max_len, 42)

        # Now let's start the test
        tokenizer = self.get_tokenizer(max_len=42)

        before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)

        with TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)
            tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

            after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
            self.assertListEqual(before_tokens, after_tokens)

            self.assertEqual(tokenizer.max_len, 42)
            tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
            self.assertEqual(tokenizer.max_len, 43)

    def test_pickle_tokenizer(self):
        tokenizer = self.get_tokenizer()
        self.assertIsNotNone(tokenizer)

        text = "Munich and Berlin are nice cities"
        subwords = tokenizer.tokenize(text)

        with TemporaryDirectory() as tmpdirname:
            filename = os.path.join(tmpdirname, "tokenizer.bin")
            with open(filename, "wb") as handle:
                pickle.dump(tokenizer, handle)

            with open(filename, "rb") as handle:
                tokenizer_new = pickle.load(handle)

        subwords_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(subwords, subwords_loaded)

    def test_added_tokens_do_lower_case(self):
        tokenizer = self.get_tokenizer(do_lower_case=True)

        special_token = tokenizer.all_special_tokens[0]

        text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
        text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

        toks0 = tokenizer.tokenize(text)  # toks before adding new_toks

        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
        added = tokenizer.add_tokens(new_toks)
        self.assertEqual(added, 2)

        toks = tokenizer.tokenize(text)
        toks2 = tokenizer.tokenize(text2)

        self.assertEqual(len(toks), len(toks2))
        self.assertNotEqual(len(toks), len(toks0))  # toks0 should be longer
        self.assertListEqual(toks, toks2)

        # Check that none of the special tokens are lowercased
        sequence_with_special_tokens = "A " + " yEs ".join(tokenizer.all_special_tokens) + " B"
        tokenized_sequence = tokenizer.tokenize(sequence_with_special_tokens)

        for special_token in tokenizer.all_special_tokens:
            self.assertTrue(special_token in tokenized_sequence)

        tokenizer = self.get_tokenizer(do_lower_case=False)

        added = tokenizer.add_tokens(new_toks)
        self.assertEqual(added, 4)

        toks = tokenizer.tokenize(text)
        toks2 = tokenizer.tokenize(text2)

        self.assertEqual(len(toks), len(toks2))  # Length should still be the same
        self.assertNotEqual(len(toks), len(toks0))
        self.assertNotEqual(toks[1], toks2[1])  # But at least the first non-special tokens should differ

    def test_add_tokens_tokenizer(self):
        tokenizer = self.get_tokenizer()

        vocab_size = tokenizer.vocab_size
        all_size = len(tokenizer)

        self.assertNotEqual(vocab_size, 0)
        self.assertEqual(vocab_size, all_size)

        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
        added_toks = tokenizer.add_tokens(new_toks)
        vocab_size_2 = tokenizer.vocab_size
        all_size_2 = len(tokenizer)

        self.assertNotEqual(vocab_size_2, 0)
        self.assertEqual(vocab_size, vocab_size_2)
        self.assertEqual(added_toks, len(new_toks))
        self.assertEqual(all_size_2, all_size + len(new_toks))

        tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
        out_string = tokenizer.decode(tokens)

        self.assertGreaterEqual(len(tokens), 4)
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

        new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
        added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
        vocab_size_3 = tokenizer.vocab_size
        all_size_3 = len(tokenizer)

        self.assertNotEqual(vocab_size_3, 0)
        self.assertEqual(vocab_size, vocab_size_3)
        self.assertEqual(added_toks_2, len(new_toks_2))
        self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

        tokens = tokenizer.encode(
            ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
        )
        out_string = tokenizer.decode(tokens)

        self.assertGreaterEqual(len(tokens), 6)
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[0], tokens[1])
        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokens[-3])
        self.assertEqual(tokens[0], tokenizer.eos_token_id)
        self.assertEqual(tokens[-2], tokenizer.pad_token_id)

    def test_add_special_tokens(self):
        tokenizer = self.get_tokenizer()
        input_text, output_text = self.get_input_output_texts()

        special_token = "[SPECIAL TOKEN]"

        tokenizer.add_special_tokens({"cls_token": special_token})
        encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
        assert len(encoded_special_token) == 1

        text = " ".join([input_text, special_token, output_text])
        encoded = tokenizer.encode(text, add_special_tokens=False)

        input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
        output_encoded = tokenizer.encode(output_text, add_special_tokens=False)
        special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
        assert encoded == input_encoded + special_token_id + output_encoded

        decoded = tokenizer.decode(encoded, skip_special_tokens=True)
        assert special_token not in decoded

    def test_required_methods_tokenizer(self):
        tokenizer = self.get_tokenizer()
        input_text, output_text = self.get_input_output_texts()

        tokens = tokenizer.tokenize(input_text)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        ids_2 = tokenizer.encode(input_text, add_special_tokens=False)
        self.assertListEqual(ids, ids_2)

        tokens_2 = tokenizer.convert_ids_to_tokens(ids)
        text_2 = tokenizer.decode(ids)

        self.assertEqual(text_2, output_text)

        self.assertNotEqual(len(tokens_2), 0)
        self.assertIsInstance(text_2, (str, unicode))

    def test_encode_decode_with_spaces(self):
        tokenizer = self.get_tokenizer()

        new_toks = ["[ABC]", "[DEF]", "GHI IHG"]
        tokenizer.add_tokens(new_toks)
        input = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
        encoded = tokenizer.encode(input, add_special_tokens=False)
        decoded = tokenizer.decode(encoded)
        self.assertEqual(decoded, input)

    def test_pretrained_model_lists(self):
        weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
        weights_lists_2 = []
        for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
            weights_lists_2.append(list(map_list.keys()))

        for weights_list_2 in weights_lists_2:
            self.assertListEqual(weights_list, weights_list_2)

    def test_mask_output(self):
        if sys.version_info <= (3, 0):
            return

        tokenizer = self.get_tokenizer()

        if tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer":
            seq_0 = "Test this method."
            seq_1 = "With these inputs."
            information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
            sequences, mask = information["input_ids"], information["token_type_ids"]
            self.assertEqual(len(sequences), len(mask))

    def test_number_of_added_tokens(self):
        tokenizer = self.get_tokenizer()

        seq_0 = "Test this method."
        seq_1 = "With these inputs."

        sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
        attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)

        # Method is implemented (e.g. not GPT-2)
        if len(attached_sequences) != 2:
            self.assertEqual(tokenizer.num_added_tokens(pair=True), len(attached_sequences) - len(sequences))

    def test_maximum_encoding_length_single_input(self):
        tokenizer = self.get_tokenizer()

        seq_0 = "This is a sentence to be encoded."
        stride = 2

        sequence = tokenizer.encode(seq_0, add_special_tokens=False)
        num_added_tokens = tokenizer.num_added_tokens()
        total_length = len(sequence) + num_added_tokens
        information = tokenizer.encode_plus(
            seq_0,
            max_length=total_length - 2,
            add_special_tokens=True,
            stride=stride,
            return_overflowing_tokens=True,
        )

        truncated_sequence = information["input_ids"]
        overflowing_tokens = information["overflowing_tokens"]

        self.assertEqual(len(overflowing_tokens), 2 + stride)
        self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])
        self.assertEqual(len(truncated_sequence), total_length - 2)
        self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))

    def test_maximum_encoding_length_pair_input(self):
        tokenizer = self.get_tokenizer()

        seq_0 = "This is a sentence to be encoded."
        seq_1 = "This is another sentence to be encoded."
        stride = 2

        sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
        sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)

        sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
        truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
            tokenizer.encode(seq_0, add_special_tokens=False),
            tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
        )

        information = tokenizer.encode_plus(
            seq_0,
            seq_1,
            max_length=len(sequence) - 2,
            add_special_tokens=True,
            stride=stride,
            truncation_strategy="only_second",
            return_overflowing_tokens=True,
        )
        information_first_truncated = tokenizer.encode_plus(
            seq_0,
            seq_1,
            max_length=len(sequence) - 2,
            add_special_tokens=True,
            stride=stride,
            truncation_strategy="only_first",
            return_overflowing_tokens=True,
        )

        truncated_sequence = information["input_ids"]
        overflowing_tokens = information["overflowing_tokens"]
        overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]

        self.assertEqual(len(overflowing_tokens), 2 + stride)
        self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride) :])
        self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride) :])
        self.assertEqual(len(truncated_sequence), len(sequence) - 2)
        self.assertEqual(truncated_sequence, truncated_second_sequence)

    def test_encode_input_type(self):
        tokenizer = self.get_tokenizer()

        sequence = "Let's encode this sequence"

        tokens = tokenizer.tokenize(sequence)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        formatted_input = tokenizer.encode(sequence, add_special_tokens=True)

        self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
        self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)

    def test_special_tokens_mask(self):
        tokenizer = self.get_tokenizer()

        sequence_0 = "Encode this."
        sequence_1 = "This one too please."

        # Testing single inputs
        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, add_special_tokens=True, return_special_tokens_mask=True
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

        filtered_sequence = [
            (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
        ]
        filtered_sequence = [x for x in filtered_sequence if x is not None]
        self.assertEqual(encoded_sequence, filtered_sequence)

        # Testing inputs pairs
        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(
            sequence_1, add_special_tokens=False
        )
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

        filtered_sequence = [
            (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
        ]
        filtered_sequence = [x for x in filtered_sequence if x is not None]
        self.assertEqual(encoded_sequence, filtered_sequence)

        # Testing with already existing special tokens
        if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
            tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, add_special_tokens=True, return_special_tokens_mask=True
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            encoded_sequence_w_special, already_has_special_tokens=True
        )
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
        self.assertEqual(special_tokens_mask_orig, special_tokens_mask)

    def test_padding_to_max_length(self):
        tokenizer = self.get_tokenizer()

        sequence = "Sequence"
        padding_size = 10
        padding_idx = tokenizer.pad_token_id

        # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
        tokenizer.padding_side = "right"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        assert sequence_length + padding_size == padded_sequence_length
        assert encoded_sequence + [padding_idx] * padding_size == padded_sequence

        # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
        tokenizer.padding_side = "left"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        assert sequence_length + padding_size == padded_sequence_length
        assert [padding_idx] * padding_size + encoded_sequence == padded_sequence

        # RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)

        tokenizer.padding_side = "right"
        padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
        padded_sequence_right_length = len(padded_sequence_right)

        tokenizer.padding_side = "left"
        padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
        padded_sequence_left_length = len(padded_sequence_left)

        assert sequence_length == padded_sequence_right_length
        assert encoded_sequence == padded_sequence_right
        assert sequence_length == padded_sequence_left_length
        assert encoded_sequence == padded_sequence_left

    def test_encode_plus_with_padding(self):
        tokenizer = self.get_tokenizer()

        sequence = "Sequence"
        padding_size = 10
        padding_idx = tokenizer.pad_token_id
        token_type_padding_idx = tokenizer.pad_token_type_id

        encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
        input_ids = encoded_sequence["input_ids"]
        token_type_ids = encoded_sequence["token_type_ids"]
        attention_mask = encoded_sequence["attention_mask"]
        special_tokens_mask = encoded_sequence["special_tokens_mask"]
        sequence_length = len(input_ids)

        # Test right padding
        tokenizer.padding_side = "right"
        padded_sequence = tokenizer.encode_plus(
            sequence,
            max_length=sequence_length + padding_size,
            pad_to_max_length=True,
            return_special_tokens_mask=True,
        )
        padded_input_ids = padded_sequence["input_ids"]
        padded_token_type_ids = padded_sequence["token_type_ids"]
        padded_attention_mask = padded_sequence["attention_mask"]
        padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
        padded_sequence_length = len(padded_input_ids)

        assert sequence_length + padding_size == padded_sequence_length
        assert input_ids + [padding_idx] * padding_size == padded_input_ids
        assert token_type_ids + [token_type_padding_idx] * padding_size == padded_token_type_ids
        assert attention_mask + [0] * padding_size == padded_attention_mask
        assert special_tokens_mask + [1] * padding_size == padded_special_tokens_mask

        # Test left padding
        tokenizer.padding_side = "left"
        padded_sequence = tokenizer.encode_plus(
            sequence,
            max_length=sequence_length + padding_size,
            pad_to_max_length=True,
            return_special_tokens_mask=True,
        )
        padded_input_ids = padded_sequence["input_ids"]
        padded_token_type_ids = padded_sequence["token_type_ids"]
        padded_attention_mask = padded_sequence["attention_mask"]
        padded_special_tokens_mask = padded_sequence["special_tokens_mask"]
        padded_sequence_length = len(padded_input_ids)

        assert sequence_length + padding_size == padded_sequence_length
        assert [padding_idx] * padding_size + input_ids == padded_input_ids
        assert [token_type_padding_idx] * padding_size + token_type_ids == padded_token_type_ids
        assert [0] * padding_size + attention_mask == padded_attention_mask
        assert [1] * padding_size + special_tokens_mask == padded_special_tokens_mask
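
Each concrete tokenizer test below consumes this new module by listing `TokenizerTesterMixin` alongside `unittest.TestCase` and overriding the two `NotImplementedError` hooks, `get_tokenizer` and `get_input_output_texts`. A hypothetical, abbreviated subclass to illustrate the contract; the class name, toy vocabulary, and fixture handling are invented for illustration and are not part of this commit (the real subclasses follow in the renamed files):

    import os
    import unittest

    from transformers.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer

    from .test_tokenization_common import TokenizerTesterMixin


    class ToyBertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

        tokenizer_class = BertTokenizer

        def setUp(self):
            super().setUp()  # the mixin creates self.tmpdirname
            # A tiny vocabulary so the mixin's save/load, pickle and padding
            # checks have something to work with (purely illustrative).
            vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "un", "##want", "##ed", "runn", "##ing", ",", "low"]
            vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                vocab_writer.write("".join(token + "\n" for token in vocab))

        def get_tokenizer(self, **kwargs):
            return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

        def get_input_output_texts(self):
            return "UNwanted,running", "unwanted, running"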

transformers/tests/tokenization_ctrl_test.py → tests/test_tokenization_ctrl.py

@@ -20,10 +20,10 @@ from io import open
 from transformers.tokenization_ctrl import VOCAB_FILES_NAMES, CTRLTokenizer

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin


-class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class CTRLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = CTRLTokenizer

@@ -63,7 +63,3 @@ class CTRLTokenizationTest(CommonTestCases.CommonTokenizerTester):
         input_bpe_tokens = [0, 1, 2, 4, 5, 1, 0, 3, 6]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_distilbert_test.py → tests/test_tokenization_distilbert.py

@@ -14,11 +14,9 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals

-import unittest
-
 from transformers.tokenization_distilbert import DistilBertTokenizer

-from .tokenization_bert_test import BertTokenizationTest
+from .test_tokenization_bert import BertTokenizationTest
 from .utils import slow

@@ -43,7 +41,3 @@ class DistilBertTokenizationTest(BertTokenizationTest):
         assert encoded_pair == [tokenizer.cls_token_id] + text + [tokenizer.sep_token_id] + text_2 + [
             tokenizer.sep_token_id
         ]
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_gpt2_test.py → tests/test_tokenization_gpt2.py

@@ -21,10 +21,10 @@ from io import open
 from transformers.tokenization_gpt2 import VOCAB_FILES_NAMES, GPT2Tokenizer

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin


-class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
+class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = GPT2Tokenizer

@@ -84,7 +84,3 @@ class GPT2TokenizationTest(CommonTestCases.CommonTokenizerTester):
         input_tokens = tokens + [tokenizer.unk_token]
         input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_openai_test.py → tests/test_tokenization_openai.py

@@ -20,10 +20,10 @@ import unittest
 from transformers.tokenization_openai import VOCAB_FILES_NAMES, OpenAIGPTTokenizer

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin


-class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = OpenAIGPTTokenizer

@@ -83,7 +83,3 @@ class OpenAIGPTTokenizationTest(CommonTestCases.CommonTokenizerTester):
         input_tokens = tokens + ["<unk>"]
         input_bpe_tokens = [14, 15, 20]
         self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_roberta_test.py → tests/test_tokenization_roberta.py

@@ -21,11 +21,11 @@ from io import open
 from transformers.tokenization_roberta import VOCAB_FILES_NAMES, RobertaTokenizer

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
 from .utils import slow


-class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = RobertaTokenizer

     def setUp(self):

@@ -111,7 +111,3 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         assert encoded_sentence == encoded_text_from_decode
         assert encoded_pair == encoded_pair_from_decode
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_t5_test.py → tests/test_tokenization_t5.py

@@ -20,13 +20,13 @@ import unittest
 from transformers.tokenization_t5 import T5Tokenizer
 from transformers.tokenization_xlnet import SPIECE_UNDERLINE

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin

 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")


-class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
+class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = T5Tokenizer

@@ -110,7 +110,3 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
                 ".",
             ],
         )
-
-
-if __name__ == "__main__":
-    unittest.main()

transformers/tests/tokenization_transfo_xl_test.py → tests/test_tokenization_transfo_xl.py

@@ -20,7 +20,7 @@ from io import open
 from transformers import is_torch_available

-from .tokenization_tests_commons import CommonTestCases
+from .test_tokenization_common import TokenizerTesterMixin
 from .utils import require_torch

@@ -29,7 +29,7 @@ if is_torch_available():
 @require_torch
-class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
+class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = TransfoXLTokenizer if is_torch_available() else None

@@ -83,7 +83,3 @@ class TransfoXLTokenizationTest(CommonTestCases.CommonTokenizerTester):
         self.assertListEqual(
             tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
         )
-
-
-if __name__ == "__main__":
-    unittest.main()