Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0558c9cb
Commit
0558c9cb
authored
Dec 10, 2019
by
thomwolf
Browse files
Merge branch 'master' into t5
parents
608a8f5b
e57d00ee
Changes
168
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
950 additions
and
64 deletions
+950
-64
transformers/tests/modeling_xlnet_test.py
transformers/tests/modeling_xlnet_test.py
+80
-15
transformers/tests/optimization_test.py
transformers/tests/optimization_test.py
+20
-15
transformers/tests/optimization_tf_test.py
transformers/tests/optimization_tf_test.py
+90
-0
transformers/tests/tokenization_albert_test.py
transformers/tests/tokenization_albert_test.py
+78
-0
transformers/tests/tokenization_auto_test.py
transformers/tests/tokenization_auto_test.py
+3
-2
transformers/tests/tokenization_bert_test.py
transformers/tests/tokenization_bert_test.py
+2
-2
transformers/tests/tokenization_distilbert_test.py
transformers/tests/tokenization_distilbert_test.py
+2
-2
transformers/tests/tokenization_roberta_test.py
transformers/tests/tokenization_roberta_test.py
+2
-2
transformers/tests/tokenization_tests_commons.py
transformers/tests/tokenization_tests_commons.py
+158
-8
transformers/tests/tokenization_transfo_xl_test.py
transformers/tests/tokenization_transfo_xl_test.py
+3
-3
transformers/tests/tokenization_utils_test.py
transformers/tests/tokenization_utils_test.py
+4
-2
transformers/tests/tokenization_xlm_test.py
transformers/tests/tokenization_xlm_test.py
+2
-2
transformers/tests/tokenization_xlnet_test.py
transformers/tests/tokenization_xlnet_test.py
+2
-2
transformers/tests/utils.py
transformers/tests/utils.py
+64
-0
transformers/tokenization_albert.py
transformers/tokenization_albert.py
+252
-0
transformers/tokenization_auto.py
transformers/tokenization_auto.py
+17
-4
transformers/tokenization_bert.py
transformers/tokenization_bert.py
+1
-1
transformers/tokenization_camembert.py
transformers/tokenization_camembert.py
+160
-0
transformers/tokenization_ctrl.py
transformers/tokenization_ctrl.py
+6
-4
transformers/tokenization_distilbert.py
transformers/tokenization_distilbert.py
+4
-0
No files found.
transformers/tests/modeling_xlnet_test.py
View file @
0558c9cb
...
@@ -21,24 +21,25 @@ import unittest
...
@@ -21,24 +21,25 @@ import unittest
import
json
import
json
import
random
import
random
import
shutil
import
shutil
import
pytest
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForTokenClassification
,
XLNetForQuestionAnswering
)
from
transformers.modeling_xlnet
import
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from
transformers.modeling_xlnet
import
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.modeling_common_test
import
(
CommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
.configuration_common_test
import
ConfigTester
from
.utils
import
require_torch
,
slow
,
torch_device
@
require_torch
class
XLNetModelTest
(
CommonTestCases
.
CommonModelTester
):
class
XLNetModelTest
(
CommonTestCases
.
CommonModelTester
):
all_model_classes
=
(
XLNetModel
,
XLNetLMHeadModel
,
all_model_classes
=
(
XLNetModel
,
XLNetLMHeadModel
,
XLNetForTokenClassification
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
if
is_torch_available
()
else
()
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
if
is_torch_available
()
else
()
test_pruning
=
False
test_pruning
=
False
...
@@ -99,18 +100,20 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -99,18 +100,20 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
2
).
float
()
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
input_ids_q
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
+
1
],
self
.
vocab_size
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
)
perm_mask
=
torch
.
zeros
(
self
.
batch_size
,
self
.
seq_length
+
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
perm_mask
[:,
:,
-
1
]
=
1.0
# Previous tokens don't see last token
perm_mask
[:,
:,
-
1
]
=
1.0
# Previous tokens don't see last token
target_mapping
=
torch
.
zeros
(
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
)
target_mapping
=
torch
.
zeros
(
self
.
batch_size
,
1
,
self
.
seq_length
+
1
,
dtype
=
torch
.
float
,
device
=
torch_device
)
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
target_mapping
[:,
0
,
-
1
]
=
1.0
# predict last token
sequence_labels
=
None
sequence_labels
=
None
lm_labels
=
None
lm_labels
=
None
is_impossible_labels
=
None
is_impossible_labels
=
None
token_labels
=
None
if
self
.
use_labels
:
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
()
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
()
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
config
=
XLNetConfig
(
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
@@ -129,15 +132,16 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -129,15 +132,16 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
num_labels
=
self
.
type_sequence_label_size
)
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
)
def
set_seed
(
self
):
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetModel
(
config
)
model
=
XLNetModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
_
,
_
=
model
(
input_ids_1
,
input_mask
=
input_mask
)
_
,
_
=
model
(
input_ids_1
,
input_mask
=
input_mask
)
...
@@ -152,6 +156,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -152,6 +156,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config
.
mem_len
=
0
config
.
mem_len
=
0
model
=
XLNetModel
(
config
)
model
=
XLNetModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
no_mems_outputs
=
model
(
input_ids_1
)
no_mems_outputs
=
model
(
input_ids_1
)
self
.
parent
.
assertEqual
(
len
(
no_mems_outputs
),
1
)
self
.
parent
.
assertEqual
(
len
(
no_mems_outputs
),
1
)
...
@@ -163,9 +168,23 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -163,9 +168,23 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_base_model_with_att_output
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
_
,
_
,
attentions
=
model
(
input_ids_1
,
target_mapping
=
target_mapping
)
self
.
parent
.
assertEqual
(
len
(
attentions
),
config
.
n_layer
)
self
.
parent
.
assertIsInstance
(
attentions
[
0
],
tuple
)
self
.
parent
.
assertEqual
(
len
(
attentions
[
0
]),
2
)
self
.
parent
.
assertTrue
(
attentions
[
0
][
0
].
shape
,
attentions
[
0
][
0
].
shape
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
=
XLNetLMHeadModel
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
loss_1
,
all_logits_1
,
mems_1
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
loss_1
,
all_logits_1
,
mems_1
=
model
(
input_ids_1
,
token_type_ids
=
segment_ids
,
labels
=
lm_labels
)
...
@@ -204,8 +223,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -204,8 +223,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForQuestionAnswering
(
config
)
model
=
XLNetForQuestionAnswering
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
outputs
=
model
(
input_ids_1
)
outputs
=
model
(
input_ids_1
)
...
@@ -261,9 +281,43 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -261,9 +281,43 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems"
]),
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_token_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForTokenClassification
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
logits
,
mems_1
=
model
(
input_ids_1
)
loss
,
logits
,
mems_1
=
model
(
input_ids_1
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"mems_1"
:
mems_1
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForSequenceClassification
(
config
)
model
=
XLNetForSequenceClassification
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
logits
,
mems_1
=
model
(
input_ids_1
)
logits
,
mems_1
=
model
(
input_ids_1
)
...
@@ -289,7 +343,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -289,7 +343,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
sequence_labels
,
is_impossible_labels
,
token_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
return
config
,
inputs_dict
...
@@ -306,22 +360,33 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
...
@@ -306,22 +360,33 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_base_model
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_xlnet_base_model
(
*
config_and_inputs
)
def
test_xlnet_base_model_with_att_output
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
[
0
].
output_attentions
=
True
self
.
model_tester
.
create_and_check_xlnet_base_model_with_att_output
(
*
config_and_inputs
)
def
test_xlnet_lm_head
(
self
):
def
test_xlnet_lm_head
(
self
):
self
.
model_tester
.
set_seed
()
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_xlnet_lm_head
(
*
config_and_inputs
)
def
test_xlnet_sequence_classif
(
self
):
def
test_xlnet_sequence_classif
(
self
):
self
.
model_tester
.
set_seed
()
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
def
test_xlnet_token_classif
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_token_classif
(
*
config_and_inputs
)
def
test_xlnet_qa
(
self
):
def
test_xlnet_qa
(
self
):
self
.
model_tester
.
set_seed
()
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_qa
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_xlnet_qa
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/transformers_test/"
cache_dir
=
"/tmp/transformers_test/"
for
model_name
in
list
(
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
...
...
transformers/tests/optimization_test.py
View file @
0558c9cb
...
@@ -18,19 +18,21 @@ from __future__ import print_function
...
@@ -18,19 +18,21 @@ from __future__ import print_function
import
unittest
import
unittest
import
os
import
os
import
pytest
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
from
transformers
import
(
AdamW
,
ConstantLRSchedule
,
WarmupConstantSchedule
,
from
transformers
import
(
AdamW
,
WarmupCosineSchedule
,
WarmupCosineWithHardRestartsSchedule
,
WarmupLinearSchedule
)
get_constant_schedule
,
else
:
get_constant_schedule_with_warmup
,
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
get_cosine_schedule_with_warmup
,
get_cosine_with_hard_restarts_schedule_with_warmup
,
get_linear_schedule_with_warmup
)
from
.tokenization_tests_commons
import
TemporaryDirectory
from
.tokenization_tests_commons
import
TemporaryDirectory
from
.utils
import
require_torch
def
unwrap_schedule
(
scheduler
,
num_steps
=
10
):
def
unwrap_schedule
(
scheduler
,
num_steps
=
10
):
...
@@ -54,6 +56,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
...
@@ -54,6 +56,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
scheduler
.
load_state_dict
(
state_dict
)
scheduler
.
load_state_dict
(
state_dict
)
return
lrs
return
lrs
@
require_torch
class
OptimizationTest
(
unittest
.
TestCase
):
class
OptimizationTest
(
unittest
.
TestCase
):
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
...
@@ -76,6 +79,7 @@ class OptimizationTest(unittest.TestCase):
...
@@ -76,6 +79,7 @@ class OptimizationTest(unittest.TestCase):
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
@
require_torch
class
ScheduleInitTest
(
unittest
.
TestCase
):
class
ScheduleInitTest
(
unittest
.
TestCase
):
m
=
torch
.
nn
.
Linear
(
50
,
50
)
if
is_torch_available
()
else
None
m
=
torch
.
nn
.
Linear
(
50
,
50
)
if
is_torch_available
()
else
None
optimizer
=
AdamW
(
m
.
parameters
(),
lr
=
10.
)
if
is_torch_available
()
else
None
optimizer
=
AdamW
(
m
.
parameters
(),
lr
=
10.
)
if
is_torch_available
()
else
None
...
@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase):
...
@@ -87,59 +91,60 @@ class ScheduleInitTest(unittest.TestCase):
self
.
assertAlmostEqual
(
a
,
b
,
delta
=
tol
)
self
.
assertAlmostEqual
(
a
,
b
,
delta
=
tol
)
def
test_constant_scheduler
(
self
):
def
test_constant_scheduler
(
self
):
scheduler
=
C
onstant
LRS
chedule
(
self
.
optimizer
)
scheduler
=
get_c
onstant
_s
chedule
(
self
.
optimizer
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
10.
]
*
self
.
num_steps
expected_learning_rates
=
[
10.
]
*
self
.
num_steps
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
scheduler
=
C
onstant
LRS
chedule
(
self
.
optimizer
)
scheduler
=
get_c
onstant
_s
chedule
(
self
.
optimizer
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
def
test_warmup_constant_scheduler
(
self
):
def
test_warmup_constant_scheduler
(
self
):
scheduler
=
WarmupC
onstant
S
chedule
(
self
.
optimizer
,
warmup_steps
=
4
)
scheduler
=
get_c
onstant
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
4
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
2.5
,
5.0
,
7.5
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
]
expected_learning_rates
=
[
2.5
,
5.0
,
7.5
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
]
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
scheduler
=
WarmupC
onstant
S
chedule
(
self
.
optimizer
,
warmup_steps
=
4
)
scheduler
=
get_c
onstant
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
4
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
def
test_warmup_linear_scheduler
(
self
):
def
test_warmup_linear_scheduler
(
self
):
scheduler
=
WarmupL
inear
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
t_total
=
10
)
scheduler
=
get_l
inear
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_training_steps
=
10
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
5.0
,
10.0
,
8.75
,
7.5
,
6.25
,
5.0
,
3.75
,
2.5
,
1.25
,
0.0
]
expected_learning_rates
=
[
5.0
,
10.0
,
8.75
,
7.5
,
6.25
,
5.0
,
3.75
,
2.5
,
1.25
,
0.0
]
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
)
scheduler
=
WarmupL
inear
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
t_total
=
10
)
scheduler
=
get_l
inear
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_training_steps
=
10
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
def
test_warmup_cosine_scheduler
(
self
):
def
test_warmup_cosine_scheduler
(
self
):
scheduler
=
WarmupC
osine
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
t_total
=
10
)
scheduler
=
get_c
osine
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_training_steps
=
10
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
5.0
,
10.0
,
9.61
,
8.53
,
6.91
,
5.0
,
3.08
,
1.46
,
0.38
,
0.0
]
expected_learning_rates
=
[
5.0
,
10.0
,
9.61
,
8.53
,
6.91
,
5.0
,
3.08
,
1.46
,
0.38
,
0.0
]
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListAlmostEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
,
tol
=
1e-2
)
self
.
assertListAlmostEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
,
tol
=
1e-2
)
scheduler
=
WarmupC
osine
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
t_total
=
10
)
scheduler
=
get_c
osine
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_training_steps
=
10
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
def
test_warmup_cosine_hard_restart_scheduler
(
self
):
def
test_warmup_cosine_hard_restart_scheduler
(
self
):
scheduler
=
WarmupC
osine
W
ith
H
ard
R
estarts
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
cycles
=
2
,
t_total
=
10
)
scheduler
=
get_c
osine
_w
ith
_h
ard
_r
estarts
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_
cycles
=
2
,
num_training_steps
=
10
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
lrs
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
expected_learning_rates
=
[
5.0
,
10.0
,
8.53
,
5.0
,
1.46
,
10.0
,
8.53
,
5.0
,
1.46
,
0.0
]
expected_learning_rates
=
[
5.0
,
10.0
,
8.53
,
5.0
,
1.46
,
10.0
,
8.53
,
5.0
,
1.46
,
0.0
]
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertEqual
(
len
(
lrs
[
0
]),
1
)
self
.
assertListAlmostEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
,
tol
=
1e-2
)
self
.
assertListAlmostEqual
([
l
[
0
]
for
l
in
lrs
],
expected_learning_rates
,
tol
=
1e-2
)
scheduler
=
WarmupC
osine
W
ith
H
ard
R
estarts
S
chedule
(
self
.
optimizer
,
warmup_steps
=
2
,
cycles
=
2
,
t_total
=
10
)
scheduler
=
get_c
osine
_w
ith
_h
ard
_r
estarts
_s
chedule
_with_warmup
(
self
.
optimizer
,
num_
warmup_steps
=
2
,
num_
cycles
=
2
,
num_training_steps
=
10
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
self
.
assertListEqual
([
l
[
0
]
for
l
in
lrs
],
[
l
[
0
]
for
l
in
lrs_2
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
transformers/tests/optimization_tf_test.py
0 → 100644
View file @
0558c9cb
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
from
transformers
import
is_tf_available
from
.utils
import
require_tf
if
is_tf_available
():
import
tensorflow
as
tf
from
tensorflow.python.eager
import
context
from
tensorflow.python.framework
import
ops
from
transformers
import
(
create_optimizer
,
GradientAccumulator
)
@
require_tf
class
OptimizationFTest
(
unittest
.
TestCase
):
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
self
.
assertEqual
(
len
(
list1
),
len
(
list2
))
for
a
,
b
in
zip
(
list1
,
list2
):
self
.
assertAlmostEqual
(
a
,
b
,
delta
=
tol
)
def
testGradientAccumulator
(
self
):
accumulator
=
GradientAccumulator
()
accumulator
([
tf
.
constant
([
1.0
,
2.0
])])
accumulator
([
tf
.
constant
([
-
2.0
,
1.0
])])
accumulator
([
tf
.
constant
([
-
1.0
,
2.0
])])
with
self
.
assertRaises
(
ValueError
):
accumulator
([
tf
.
constant
([
1.0
,
1.0
]),
tf
.
constant
([
2.0
,
2.0
])])
self
.
assertEqual
(
accumulator
.
step
,
3
)
self
.
assertEqual
(
len
(
accumulator
.
gradients
),
1
)
self
.
assertListAlmostEqual
(
accumulator
.
gradients
[
0
].
numpy
().
tolist
(),
[
-
2.0
,
5.0
],
tol
=
1e-2
)
accumulator
.
reset
()
self
.
assertEqual
(
accumulator
.
step
,
0
)
self
.
assertListAlmostEqual
(
accumulator
.
gradients
[
0
].
numpy
().
tolist
(),
[
0.0
,
0.0
],
tol
=
1e-2
)
def
testGradientAccumulatorDistributionStrategy
(
self
):
context
.
_context
=
None
ops
.
enable_eager_execution_internal
()
physical_devices
=
tf
.
config
.
experimental
.
list_physical_devices
(
"CPU"
)
tf
.
config
.
experimental
.
set_virtual_device_configuration
(
physical_devices
[
0
],
[
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
(),
tf
.
config
.
experimental
.
VirtualDeviceConfiguration
()])
devices
=
tf
.
config
.
experimental
.
list_logical_devices
(
device_type
=
"CPU"
)
strategy
=
tf
.
distribute
.
MirroredStrategy
(
devices
=
[
device
.
name
for
device
in
devices
])
with
strategy
.
scope
():
accumulator
=
GradientAccumulator
()
variable
=
tf
.
Variable
([
4.0
,
3.0
])
optimizer
=
create_optimizer
(
5e-5
,
10
,
5
)
gradient_placeholder
=
tf
.
Variable
([
0.0
,
0.0
],
trainable
=
False
)
def
accumulate_on_replica
(
gradient
):
accumulator
([
gradient
])
def
apply_on_replica
():
optimizer
.
apply_gradients
(
list
(
zip
(
accumulator
.
gradients
,
[
variable
])),
1.0
)
@
tf
.
function
def
accumulate
(
grad1
,
grad2
):
with
strategy
.
scope
():
gradient_placeholder
.
values
[
0
].
assign
(
grad1
)
gradient_placeholder
.
values
[
1
].
assign
(
grad2
)
strategy
.
experimental_run_v2
(
accumulate_on_replica
,
args
=
(
gradient_placeholder
,))
@
tf
.
function
def
apply_grad
():
with
strategy
.
scope
():
strategy
.
experimental_run_v2
(
apply_on_replica
)
accumulate
([
1.0
,
2.0
],
[
-
1.0
,
1.0
])
accumulate
([
3.0
,
-
1.0
],
[
-
1.0
,
-
1.0
])
accumulate
([
-
2.0
,
2.0
],
[
3.0
,
-
2.0
])
self
.
assertEqual
(
accumulator
.
step
,
3
)
self
.
assertListAlmostEqual
(
accumulator
.
_gradients
[
0
].
values
[
0
].
value
().
numpy
().
tolist
(),
[
2.0
,
3.0
],
tol
=
1e-2
)
self
.
assertListAlmostEqual
(
accumulator
.
_gradients
[
0
].
values
[
1
].
value
().
numpy
().
tolist
(),
[
1.0
,
-
2.0
],
tol
=
1e-2
)
apply_grad
()
self
.
assertListAlmostEqual
(
variable
.
value
().
numpy
().
tolist
(),
[
4.0
,
3.0
],
tol
=
1e-2
)
accumulator
.
reset
()
self
.
assertEqual
(
accumulator
.
step
,
0
)
self
.
assertListAlmostEqual
(
accumulator
.
_gradients
[
0
].
values
[
0
].
value
().
numpy
().
tolist
(),
[
0.0
,
0.0
],
tol
=
1e-2
)
self
.
assertListAlmostEqual
(
accumulator
.
_gradients
[
0
].
values
[
1
].
value
().
numpy
().
tolist
(),
[
0.0
,
0.0
],
tol
=
1e-2
)
if
__name__
==
"__main__"
:
unittest
.
main
()
\ No newline at end of file
transformers/tests/tokenization_albert_test.py
0 → 100644
View file @
0558c9cb
# coding=utf-8
# Copyright 2019 Hugging Face inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
os
import
unittest
from
transformers.tokenization_albert
import
(
AlbertTokenizer
,
SPIECE_UNDERLINE
)
from
.tokenization_tests_commons
import
CommonTestCases
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'fixtures/spiece.model'
)
class
AlbertTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
tokenizer_class
=
AlbertTokenizer
def
setUp
(
self
):
super
(
AlbertTokenizationTest
,
self
).
setUp
()
# We have a SentencePiece fixture for testing
tokenizer
=
AlbertTokenizer
(
SAMPLE_VOCAB
)
tokenizer
.
save_pretrained
(
self
.
tmpdirname
)
def
get_tokenizer
(
self
,
**
kwargs
):
return
AlbertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
def
get_input_output_texts
(
self
):
input_text
=
u
"this is a test"
output_text
=
u
"this is a test"
return
input_text
,
output_text
def
test_full_tokenizer
(
self
):
tokenizer
=
AlbertTokenizer
(
SAMPLE_VOCAB
,
keep_accents
=
True
)
tokens
=
tokenizer
.
tokenize
(
u
'This is a test'
)
self
.
assertListEqual
(
tokens
,
[
u
'▁this'
,
u
'▁is'
,
u
'▁a'
,
u
'▁test'
])
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
tokens
),
[
48
,
25
,
21
,
1289
])
tokens
=
tokenizer
.
tokenize
(
u
"I was born in 92000, and this is falsé."
)
self
.
assertListEqual
(
tokens
,
[
u
'▁i'
,
u
'▁was'
,
u
'▁born'
,
u
'▁in'
,
u
'▁9'
,
u
'2000'
,
u
','
,
u
'▁and'
,
u
'▁this'
,
u
'▁is'
,
u
'▁fal'
,
u
's'
,
u
'é'
,
u
'.'
])
ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
self
.
assertListEqual
(
ids
,
[
31
,
23
,
386
,
19
,
561
,
3050
,
15
,
17
,
48
,
25
,
8256
,
18
,
1
,
9
])
back_tokens
=
tokenizer
.
convert_ids_to_tokens
(
ids
)
self
.
assertListEqual
(
back_tokens
,
[
'▁i'
,
'▁was'
,
'▁born'
,
'▁in'
,
'▁9'
,
'2000'
,
','
,
'▁and'
,
'▁this'
,
'▁is'
,
'▁fal'
,
's'
,
'<unk>'
,
'.'
])
def
test_sequence_builders
(
self
):
tokenizer
=
AlbertTokenizer
(
SAMPLE_VOCAB
)
text
=
tokenizer
.
encode
(
"sequence builders"
)
text_2
=
tokenizer
.
encode
(
"multi-sequence build"
)
encoded_sentence
=
tokenizer
.
build_inputs_with_special_tokens
(
text
)
encoded_pair
=
tokenizer
.
build_inputs_with_special_tokens
(
text
,
text_2
)
assert
encoded_sentence
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
assert
encoded_pair
==
[
tokenizer
.
cls_token_id
]
+
text
+
[
tokenizer
.
sep_token_id
]
+
text_2
+
[
tokenizer
.
sep_token_id
]
if
__name__
==
'__main__'
:
unittest
.
main
()
transformers/tests/tokenization_auto_test.py
View file @
0558c9cb
...
@@ -18,15 +18,16 @@ from __future__ import print_function
...
@@ -18,15 +18,16 @@ from __future__ import print_function
import
unittest
import
unittest
import
shutil
import
shutil
import
pytest
import
logging
import
logging
from
transformers
import
AutoTokenizer
,
BertTokenizer
,
AutoTokenizer
,
GPT2Tokenizer
from
transformers
import
AutoTokenizer
,
BertTokenizer
,
AutoTokenizer
,
GPT2Tokenizer
from
transformers
import
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from
transformers
import
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from
.utils
import
slow
class
AutoTokenizerTest
(
unittest
.
TestCase
):
class
AutoTokenizerTest
(
unittest
.
TestCase
):
@
pytest
.
mark
.
slow
@
slow
def
test_tokenizer_from_pretrained
(
self
):
def
test_tokenizer_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
())[:
1
]:
for
model_name
in
list
(
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
.
keys
())[:
1
]:
...
...
transformers/tests/tokenization_bert_test.py
View file @
0558c9cb
...
@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
unittest
import
unittest
import
pytest
from
io
import
open
from
io
import
open
from
transformers.tokenization_bert
import
(
BasicTokenizer
,
from
transformers.tokenization_bert
import
(
BasicTokenizer
,
...
@@ -26,6 +25,7 @@ from transformers.tokenization_bert import (BasicTokenizer,
...
@@ -26,6 +25,7 @@ from transformers.tokenization_bert import (BasicTokenizer,
_is_whitespace
,
VOCAB_FILES_NAMES
)
_is_whitespace
,
VOCAB_FILES_NAMES
)
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
class
BertTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
class
BertTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
...
@@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
...
@@ -126,7 +126,7 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
self
.
assertFalse
(
_is_punctuation
(
u
"A"
))
self
.
assertFalse
(
_is_punctuation
(
u
"A"
))
self
.
assertFalse
(
_is_punctuation
(
u
" "
))
self
.
assertFalse
(
_is_punctuation
(
u
" "
))
@
pytest
.
mark
.
slow
@
slow
def
test_sequence_builders
(
self
):
def
test_sequence_builders
(
self
):
tokenizer
=
self
.
tokenizer_class
.
from_pretrained
(
"bert-base-uncased"
)
tokenizer
=
self
.
tokenizer_class
.
from_pretrained
(
"bert-base-uncased"
)
...
...
transformers/tests/tokenization_distilbert_test.py
View file @
0558c9cb
...
@@ -16,13 +16,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -16,13 +16,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
unittest
import
unittest
import
pytest
from
io
import
open
from
io
import
open
from
transformers.tokenization_distilbert
import
(
DistilBertTokenizer
)
from
transformers.tokenization_distilbert
import
(
DistilBertTokenizer
)
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_bert_test
import
BertTokenizationTest
from
.tokenization_bert_test
import
BertTokenizationTest
from
.utils
import
slow
class
DistilBertTokenizationTest
(
BertTokenizationTest
):
class
DistilBertTokenizationTest
(
BertTokenizationTest
):
...
@@ -31,7 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
...
@@ -31,7 +31,7 @@ class DistilBertTokenizationTest(BertTokenizationTest):
def
get_tokenizer
(
self
,
**
kwargs
):
def
get_tokenizer
(
self
,
**
kwargs
):
return
DistilBertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
return
DistilBertTokenizer
.
from_pretrained
(
self
.
tmpdirname
,
**
kwargs
)
@
pytest
.
mark
.
slow
@
slow
def
test_sequence_builders
(
self
):
def
test_sequence_builders
(
self
):
tokenizer
=
DistilBertTokenizer
.
from_pretrained
(
"distilbert-base-uncased"
)
tokenizer
=
DistilBertTokenizer
.
from_pretrained
(
"distilbert-base-uncased"
)
...
...
transformers/tests/tokenization_roberta_test.py
View file @
0558c9cb
...
@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
json
import
json
import
unittest
import
unittest
import
pytest
from
io
import
open
from
io
import
open
from
transformers.tokenization_roberta
import
RobertaTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_roberta
import
RobertaTokenizer
,
VOCAB_FILES_NAMES
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
class
RobertaTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
class
RobertaTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
...
@@ -79,7 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
...
@@ -79,7 +79,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
[
0
,
31414
,
232
,
328
,
740
,
1140
,
12695
,
69
,
46078
,
1588
,
2
]
[
0
,
31414
,
232
,
328
,
740
,
1140
,
12695
,
69
,
46078
,
1588
,
2
]
)
)
@
pytest
.
mark
.
slow
@
slow
def
test_sequence_builders
(
self
):
def
test_sequence_builders
(
self
):
tokenizer
=
RobertaTokenizer
.
from_pretrained
(
"roberta-base"
)
tokenizer
=
RobertaTokenizer
.
from_pretrained
(
"roberta-base"
)
...
...
transformers/tests/tokenization_tests_commons.py
View file @
0558c9cb
...
@@ -102,14 +102,48 @@ class CommonTestCases:
...
@@ -102,14 +102,48 @@ class CommonTestCases:
with
TemporaryDirectory
()
as
tmpdirname
:
with
TemporaryDirectory
()
as
tmpdirname
:
filename
=
os
.
path
.
join
(
tmpdirname
,
u
"tokenizer.bin"
)
filename
=
os
.
path
.
join
(
tmpdirname
,
u
"tokenizer.bin"
)
pickle
.
dump
(
tokenizer
,
open
(
filename
,
"wb"
))
with
open
(
filename
,
"wb"
)
as
handle
:
pickle
.
dump
(
tokenizer
,
handle
)
tokenizer_new
=
pickle
.
load
(
open
(
filename
,
"rb"
))
with
open
(
filename
,
"rb"
)
as
handle
:
tokenizer_new
=
pickle
.
load
(
handle
)
subwords_loaded
=
tokenizer_new
.
tokenize
(
text
)
subwords_loaded
=
tokenizer_new
.
tokenize
(
text
)
self
.
assertListEqual
(
subwords
,
subwords_loaded
)
self
.
assertListEqual
(
subwords
,
subwords_loaded
)
def
test_added_tokens_do_lower_case
(
self
):
tokenizer
=
self
.
get_tokenizer
(
do_lower_case
=
True
)
special_token
=
tokenizer
.
all_special_tokens
[
0
]
text
=
special_token
+
" aaaaa bbbbbb low cccccccccdddddddd l "
+
special_token
text2
=
special_token
+
" AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l "
+
special_token
toks0
=
tokenizer
.
tokenize
(
text
)
# toks before adding new_toks
new_toks
=
[
"aaaaa bbbbbb"
,
"cccccccccdddddddd"
,
'AAAAA BBBBBB'
,
'CCCCCCCCCDDDDDDDD'
]
added
=
tokenizer
.
add_tokens
(
new_toks
)
self
.
assertEqual
(
added
,
2
)
toks
=
tokenizer
.
tokenize
(
text
)
toks2
=
tokenizer
.
tokenize
(
text2
)
self
.
assertEqual
(
len
(
toks
),
len
(
toks2
))
self
.
assertNotEqual
(
len
(
toks
),
len
(
toks0
))
# toks0 should be longer
self
.
assertListEqual
(
toks
,
toks2
)
tokenizer
=
self
.
get_tokenizer
(
do_lower_case
=
False
)
added
=
tokenizer
.
add_tokens
(
new_toks
)
self
.
assertEqual
(
added
,
4
)
toks
=
tokenizer
.
tokenize
(
text
)
toks2
=
tokenizer
.
tokenize
(
text2
)
self
.
assertEqual
(
len
(
toks
),
len
(
toks2
))
# Length should still be the same
self
.
assertNotEqual
(
len
(
toks
),
len
(
toks0
))
self
.
assertNotEqual
(
toks
[
1
],
toks2
[
1
])
# But at least the first non-special tokens should differ
def
test_add_tokens_tokenizer
(
self
):
def
test_add_tokens_tokenizer
(
self
):
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
...
@@ -160,6 +194,26 @@ class CommonTestCases:
...
@@ -160,6 +194,26 @@ class CommonTestCases:
self
.
assertEqual
(
tokens
[
0
],
tokenizer
.
eos_token_id
)
self
.
assertEqual
(
tokens
[
0
],
tokenizer
.
eos_token_id
)
self
.
assertEqual
(
tokens
[
-
2
],
tokenizer
.
pad_token_id
)
self
.
assertEqual
(
tokens
[
-
2
],
tokenizer
.
pad_token_id
)
def
test_add_special_tokens
(
self
):
tokenizer
=
self
.
get_tokenizer
()
input_text
,
output_text
=
self
.
get_input_output_texts
()
special_token
=
"[SPECIAL TOKEN]"
tokenizer
.
add_special_tokens
({
"cls_token"
:
special_token
})
encoded_special_token
=
tokenizer
.
encode
(
special_token
,
add_special_tokens
=
False
)
assert
len
(
encoded_special_token
)
==
1
text
=
" "
.
join
([
input_text
,
special_token
,
output_text
])
encoded
=
tokenizer
.
encode
(
text
,
add_special_tokens
=
False
)
input_encoded
=
tokenizer
.
encode
(
input_text
,
add_special_tokens
=
False
)
output_encoded
=
tokenizer
.
encode
(
output_text
,
add_special_tokens
=
False
)
special_token_id
=
tokenizer
.
encode
(
special_token
,
add_special_tokens
=
False
)
assert
encoded
==
input_encoded
+
special_token_id
+
output_encoded
decoded
=
tokenizer
.
decode
(
encoded
,
skip_special_tokens
=
True
)
assert
special_token
not
in
decoded
def
test_required_methods_tokenizer
(
self
):
def
test_required_methods_tokenizer
(
self
):
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
...
@@ -223,7 +277,11 @@ class CommonTestCases:
...
@@ -223,7 +277,11 @@ class CommonTestCases:
sequence
=
tokenizer
.
encode
(
seq_0
,
add_special_tokens
=
False
)
sequence
=
tokenizer
.
encode
(
seq_0
,
add_special_tokens
=
False
)
num_added_tokens
=
tokenizer
.
num_added_tokens
()
num_added_tokens
=
tokenizer
.
num_added_tokens
()
total_length
=
len
(
sequence
)
+
num_added_tokens
total_length
=
len
(
sequence
)
+
num_added_tokens
information
=
tokenizer
.
encode_plus
(
seq_0
,
max_length
=
total_length
-
2
,
add_special_tokens
=
True
,
stride
=
stride
)
information
=
tokenizer
.
encode_plus
(
seq_0
,
max_length
=
total_length
-
2
,
add_special_tokens
=
True
,
stride
=
stride
,
return_overflowing_tokens
=
True
)
truncated_sequence
=
information
[
"input_ids"
]
truncated_sequence
=
information
[
"input_ids"
]
overflowing_tokens
=
information
[
"overflowing_tokens"
]
overflowing_tokens
=
information
[
"overflowing_tokens"
]
...
@@ -250,10 +308,12 @@ class CommonTestCases:
...
@@ -250,10 +308,12 @@ class CommonTestCases:
)
)
information
=
tokenizer
.
encode_plus
(
seq_0
,
seq_1
,
max_length
=
len
(
sequence
)
-
2
,
add_special_tokens
=
True
,
information
=
tokenizer
.
encode_plus
(
seq_0
,
seq_1
,
max_length
=
len
(
sequence
)
-
2
,
add_special_tokens
=
True
,
stride
=
stride
,
truncation_strategy
=
'only_second'
)
stride
=
stride
,
truncation_strategy
=
'only_second'
,
return_overflowing_tokens
=
True
)
information_first_truncated
=
tokenizer
.
encode_plus
(
seq_0
,
seq_1
,
max_length
=
len
(
sequence
)
-
2
,
information_first_truncated
=
tokenizer
.
encode_plus
(
seq_0
,
seq_1
,
max_length
=
len
(
sequence
)
-
2
,
add_special_tokens
=
True
,
stride
=
stride
,
add_special_tokens
=
True
,
stride
=
stride
,
truncation_strategy
=
'only_first'
)
truncation_strategy
=
'only_first'
,
return_overflowing_tokens
=
True
)
truncated_sequence
=
information
[
"input_ids"
]
truncated_sequence
=
information
[
"input_ids"
]
overflowing_tokens
=
information
[
"overflowing_tokens"
]
overflowing_tokens
=
information
[
"overflowing_tokens"
]
...
@@ -285,7 +345,7 @@ class CommonTestCases:
...
@@ -285,7 +345,7 @@ class CommonTestCases:
# Testing single inputs
# Testing single inputs
encoded_sequence
=
tokenizer
.
encode
(
sequence_0
,
add_special_tokens
=
False
)
encoded_sequence
=
tokenizer
.
encode
(
sequence_0
,
add_special_tokens
=
False
)
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
add_special_tokens
=
True
)
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
add_special_tokens
=
True
,
return_special_tokens_mask
=
True
)
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
special_tokens_mask
=
encoded_sequence_dict
[
"special_tokens_mask"
]
special_tokens_mask
=
encoded_sequence_dict
[
"special_tokens_mask"
]
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
...
@@ -297,7 +357,8 @@ class CommonTestCases:
...
@@ -297,7 +357,8 @@ class CommonTestCases:
# Testing inputs pairs
# Testing inputs pairs
encoded_sequence
=
tokenizer
.
encode
(
sequence_0
,
add_special_tokens
=
False
)
+
tokenizer
.
encode
(
sequence_1
,
encoded_sequence
=
tokenizer
.
encode
(
sequence_0
,
add_special_tokens
=
False
)
+
tokenizer
.
encode
(
sequence_1
,
add_special_tokens
=
False
)
add_special_tokens
=
False
)
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
sequence_1
,
add_special_tokens
=
True
)
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
sequence_1
,
add_special_tokens
=
True
,
return_special_tokens_mask
=
True
)
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
special_tokens_mask
=
encoded_sequence_dict
[
"special_tokens_mask"
]
special_tokens_mask
=
encoded_sequence_dict
[
"special_tokens_mask"
]
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
...
@@ -309,9 +370,98 @@ class CommonTestCases:
...
@@ -309,9 +370,98 @@ class CommonTestCases:
# Testing with already existing special tokens
# Testing with already existing special tokens
if
tokenizer
.
cls_token_id
==
tokenizer
.
unk_token_id
and
tokenizer
.
cls_token_id
==
tokenizer
.
unk_token_id
:
if
tokenizer
.
cls_token_id
==
tokenizer
.
unk_token_id
and
tokenizer
.
cls_token_id
==
tokenizer
.
unk_token_id
:
tokenizer
.
add_special_tokens
({
'cls_token'
:
'</s>'
,
'sep_token'
:
'<s>'
})
tokenizer
.
add_special_tokens
({
'cls_token'
:
'</s>'
,
'sep_token'
:
'<s>'
})
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
add_special_tokens
=
True
)
encoded_sequence_dict
=
tokenizer
.
encode_plus
(
sequence_0
,
add_special_tokens
=
True
,
return_special_tokens_mask
=
True
)
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
encoded_sequence_w_special
=
encoded_sequence_dict
[
"input_ids"
]
special_tokens_mask_orig
=
encoded_sequence_dict
[
"special_tokens_mask"
]
special_tokens_mask_orig
=
encoded_sequence_dict
[
"special_tokens_mask"
]
special_tokens_mask
=
tokenizer
.
get_special_tokens_mask
(
encoded_sequence_w_special
,
already_has_special_tokens
=
True
)
special_tokens_mask
=
tokenizer
.
get_special_tokens_mask
(
encoded_sequence_w_special
,
already_has_special_tokens
=
True
)
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
self
.
assertEqual
(
len
(
special_tokens_mask
),
len
(
encoded_sequence_w_special
))
self
.
assertEqual
(
special_tokens_mask_orig
,
special_tokens_mask
)
self
.
assertEqual
(
special_tokens_mask_orig
,
special_tokens_mask
)
def
test_padding_to_max_length
(
self
):
tokenizer
=
self
.
get_tokenizer
()
sequence
=
"Sequence"
padding_size
=
10
padding_idx
=
tokenizer
.
pad_token_id
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer
.
padding_side
=
"right"
encoded_sequence
=
tokenizer
.
encode
(
sequence
)
sequence_length
=
len
(
encoded_sequence
)
padded_sequence
=
tokenizer
.
encode
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
pad_to_max_length
=
True
)
padded_sequence_length
=
len
(
padded_sequence
)
assert
sequence_length
+
padding_size
==
padded_sequence_length
assert
encoded_sequence
+
[
padding_idx
]
*
padding_size
==
padded_sequence
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer
.
padding_side
=
"left"
encoded_sequence
=
tokenizer
.
encode
(
sequence
)
sequence_length
=
len
(
encoded_sequence
)
padded_sequence
=
tokenizer
.
encode
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
pad_to_max_length
=
True
)
padded_sequence_length
=
len
(
padded_sequence
)
assert
sequence_length
+
padding_size
==
padded_sequence_length
assert
[
padding_idx
]
*
padding_size
+
encoded_sequence
==
padded_sequence
# RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
encoded_sequence
=
tokenizer
.
encode
(
sequence
)
sequence_length
=
len
(
encoded_sequence
)
tokenizer
.
padding_side
=
"right"
padded_sequence_right
=
tokenizer
.
encode
(
sequence
,
pad_to_max_length
=
True
)
padded_sequence_right_length
=
len
(
padded_sequence_right
)
tokenizer
.
padding_side
=
"left"
padded_sequence_left
=
tokenizer
.
encode
(
sequence
,
pad_to_max_length
=
True
)
padded_sequence_left_length
=
len
(
padded_sequence_left
)
assert
sequence_length
==
padded_sequence_right_length
assert
encoded_sequence
==
padded_sequence_right
assert
sequence_length
==
padded_sequence_left_length
assert
encoded_sequence
==
padded_sequence_left
def
test_encode_plus_with_padding
(
self
):
tokenizer
=
self
.
get_tokenizer
()
sequence
=
"Sequence"
padding_size
=
10
padding_idx
=
tokenizer
.
pad_token_id
token_type_padding_idx
=
tokenizer
.
pad_token_type_id
encoded_sequence
=
tokenizer
.
encode_plus
(
sequence
,
return_special_tokens_mask
=
True
)
input_ids
=
encoded_sequence
[
'input_ids'
]
token_type_ids
=
encoded_sequence
[
'token_type_ids'
]
attention_mask
=
encoded_sequence
[
'attention_mask'
]
special_tokens_mask
=
encoded_sequence
[
'special_tokens_mask'
]
sequence_length
=
len
(
input_ids
)
# Test right padding
tokenizer
.
padding_side
=
"right"
padded_sequence
=
tokenizer
.
encode_plus
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
pad_to_max_length
=
True
,
return_special_tokens_mask
=
True
)
padded_input_ids
=
padded_sequence
[
'input_ids'
]
padded_token_type_ids
=
padded_sequence
[
'token_type_ids'
]
padded_attention_mask
=
padded_sequence
[
'attention_mask'
]
padded_special_tokens_mask
=
padded_sequence
[
'special_tokens_mask'
]
padded_sequence_length
=
len
(
padded_input_ids
)
assert
sequence_length
+
padding_size
==
padded_sequence_length
assert
input_ids
+
[
padding_idx
]
*
padding_size
==
padded_input_ids
assert
token_type_ids
+
[
token_type_padding_idx
]
*
padding_size
==
padded_token_type_ids
assert
attention_mask
+
[
0
]
*
padding_size
==
padded_attention_mask
assert
special_tokens_mask
+
[
1
]
*
padding_size
==
padded_special_tokens_mask
# Test left padding
tokenizer
.
padding_side
=
"left"
padded_sequence
=
tokenizer
.
encode_plus
(
sequence
,
max_length
=
sequence_length
+
padding_size
,
pad_to_max_length
=
True
,
return_special_tokens_mask
=
True
)
padded_input_ids
=
padded_sequence
[
'input_ids'
]
padded_token_type_ids
=
padded_sequence
[
'token_type_ids'
]
padded_attention_mask
=
padded_sequence
[
'attention_mask'
]
padded_special_tokens_mask
=
padded_sequence
[
'special_tokens_mask'
]
padded_sequence_length
=
len
(
padded_input_ids
)
assert
sequence_length
+
padding_size
==
padded_sequence_length
assert
[
padding_idx
]
*
padding_size
+
input_ids
==
padded_input_ids
assert
[
token_type_padding_idx
]
*
padding_size
+
token_type_ids
==
padded_token_type_ids
assert
[
0
]
*
padding_size
+
attention_mask
==
padded_attention_mask
assert
[
1
]
*
padding_size
+
special_tokens_mask
==
padded_special_tokens_mask
\ No newline at end of file
transformers/tests/tokenization_transfo_xl_test.py
View file @
0558c9cb
...
@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -16,7 +16,6 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
unittest
import
unittest
import
pytest
from
io
import
open
from
io
import
open
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -24,11 +23,12 @@ from transformers import is_torch_available
...
@@ -24,11 +23,12 @@ from transformers import is_torch_available
if
is_torch_available
():
if
is_torch_available
():
import
torch
import
torch
from
transformers.tokenization_transfo_xl
import
TransfoXLTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_transfo_xl
import
TransfoXLTokenizer
,
VOCAB_FILES_NAMES
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
# TODO: untangle Transfo-XL tokenizer from torch.load and torch.save
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
require_torch
@
require_torch
class
TransfoXLTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
class
TransfoXLTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
tokenizer_class
=
TransfoXLTokenizer
if
is_torch_available
()
else
None
tokenizer_class
=
TransfoXLTokenizer
if
is_torch_available
()
else
None
...
...
transformers/tests/tokenization_utils_test.py
View file @
0558c9cb
...
@@ -18,13 +18,14 @@ from __future__ import print_function
...
@@ -18,13 +18,14 @@ from __future__ import print_function
import
unittest
import
unittest
import
six
import
six
import
pytest
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
transformers.tokenization_gpt2
import
GPT2Tokenizer
from
transformers.tokenization_gpt2
import
GPT2Tokenizer
from
.utils
import
slow
class
TokenizerUtilsTest
(
unittest
.
TestCase
):
class
TokenizerUtilsTest
(
unittest
.
TestCase
):
@
pytest
.
mark
.
slow
def
check_tokenizer_from_pretrained
(
self
,
tokenizer_class
):
def
check_tokenizer_from_pretrained
(
self
,
tokenizer_class
):
s3_models
=
list
(
tokenizer_class
.
max_model_input_sizes
.
keys
())
s3_models
=
list
(
tokenizer_class
.
max_model_input_sizes
.
keys
())
for
model_name
in
s3_models
[:
1
]:
for
model_name
in
s3_models
[:
1
]:
...
@@ -41,6 +42,7 @@ class TokenizerUtilsTest(unittest.TestCase):
...
@@ -41,6 +42,7 @@ class TokenizerUtilsTest(unittest.TestCase):
special_tok_id
=
tokenizer
.
convert_tokens_to_ids
(
special_tok
)
special_tok_id
=
tokenizer
.
convert_tokens_to_ids
(
special_tok
)
self
.
assertIsInstance
(
special_tok_id
,
int
)
self
.
assertIsInstance
(
special_tok_id
,
int
)
@
slow
def
test_pretrained_tokenizers
(
self
):
def
test_pretrained_tokenizers
(
self
):
self
.
check_tokenizer_from_pretrained
(
GPT2Tokenizer
)
self
.
check_tokenizer_from_pretrained
(
GPT2Tokenizer
)
...
...
transformers/tests/tokenization_xlm_test.py
View file @
0558c9cb
...
@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -17,11 +17,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
unittest
import
unittest
import
json
import
json
import
pytest
from
transformers.tokenization_xlm
import
XLMTokenizer
,
VOCAB_FILES_NAMES
from
transformers.tokenization_xlm
import
XLMTokenizer
,
VOCAB_FILES_NAMES
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
class
XLMTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
class
XLMTokenizationTest
(
CommonTestCases
.
CommonTokenizerTester
):
...
@@ -67,7 +67,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
...
@@ -67,7 +67,7 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
self
.
assertListEqual
(
self
.
assertListEqual
(
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
tokenizer
.
convert_tokens_to_ids
(
input_tokens
),
input_bpe_tokens
)
@
pytest
.
mark
.
slow
@
slow
def
test_sequence_builders
(
self
):
def
test_sequence_builders
(
self
):
tokenizer
=
XLMTokenizer
.
from_pretrained
(
"xlm-mlm-en-2048"
)
tokenizer
=
XLMTokenizer
.
from_pretrained
(
"xlm-mlm-en-2048"
)
...
...
transformers/tests/tokenization_xlnet_test.py
View file @
0558c9cb
...
@@ -16,11 +16,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
...
@@ -16,11 +16,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import
os
import
os
import
unittest
import
unittest
import
pytest
from
transformers.tokenization_xlnet
import
(
XLNetTokenizer
,
SPIECE_UNDERLINE
)
from
transformers.tokenization_xlnet
import
(
XLNetTokenizer
,
SPIECE_UNDERLINE
)
from
.tokenization_tests_commons
import
CommonTestCases
from
.tokenization_tests_commons
import
CommonTestCases
from
.utils
import
slow
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
SAMPLE_VOCAB
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
'fixtures/test_sentencepiece.model'
)
'fixtures/test_sentencepiece.model'
)
...
@@ -90,7 +90,7 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
...
@@ -90,7 +90,7 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
u
'9'
,
u
'2'
,
u
'0'
,
u
'0'
,
u
'0'
,
u
','
,
SPIECE_UNDERLINE
+
u
'and'
,
SPIECE_UNDERLINE
+
u
'this'
,
u
'9'
,
u
'2'
,
u
'0'
,
u
'0'
,
u
'0'
,
u
','
,
SPIECE_UNDERLINE
+
u
'and'
,
SPIECE_UNDERLINE
+
u
'this'
,
SPIECE_UNDERLINE
+
u
'is'
,
SPIECE_UNDERLINE
+
u
'f'
,
u
'al'
,
u
'se'
,
u
'.'
])
SPIECE_UNDERLINE
+
u
'is'
,
SPIECE_UNDERLINE
+
u
'f'
,
u
'al'
,
u
'se'
,
u
'.'
])
@
pytest
.
mark
.
slow
@
slow
def
test_sequence_builders
(
self
):
def
test_sequence_builders
(
self
):
tokenizer
=
XLNetTokenizer
.
from_pretrained
(
"xlnet-base-cased"
)
tokenizer
=
XLNetTokenizer
.
from_pretrained
(
"xlnet-base-cased"
)
...
...
transformers/tests/utils.py
0 → 100644
View file @
0558c9cb
import
os
import
unittest
from
distutils.util
import
strtobool
from
transformers.file_utils
import
_tf_available
,
_torch_available
try
:
run_slow
=
os
.
environ
[
"RUN_SLOW"
]
except
KeyError
:
# RUN_SLOW isn't set, default to skipping slow tests.
_run_slow_tests
=
False
else
:
# RUN_SLOW is set, convert it to True or False.
try
:
_run_slow_tests
=
strtobool
(
run_slow
)
except
ValueError
:
# More values are supported, but let's keep the message simple.
raise
ValueError
(
"If set, RUN_SLOW must be yes or no."
)
def
slow
(
test_case
):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable
to a truthy value to run them.
"""
if
not
_run_slow_tests
:
test_case
=
unittest
.
skip
(
"test is slow"
)(
test_case
)
return
test_case
def
require_torch
(
test_case
):
"""
Decorator marking a test that requires PyTorch.
These tests are skipped when PyTorch isn't installed.
"""
if
not
_torch_available
:
test_case
=
unittest
.
skip
(
"test requires PyTorch"
)(
test_case
)
return
test_case
def
require_tf
(
test_case
):
"""
Decorator marking a test that requires TensorFlow.
These tests are skipped when TensorFlow isn't installed.
"""
if
not
_tf_available
:
test_case
=
unittest
.
skip
(
"test requires TensorFlow"
)(
test_case
)
return
test_case
if
_torch_available
:
# Set the USE_CUDA environment variable to select a GPU.
torch_device
=
"cuda"
if
os
.
environ
.
get
(
"USE_CUDA"
)
else
"cpu"
else
:
torch_device
=
None
transformers/tokenization_albert.py
0 → 100644
View file @
0558c9cb
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for ALBERT model."""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
from
.tokenization_utils
import
PreTrainedTokenizer
import
logging
import
unicodedata
import
six
import
os
from
shutil
import
copyfile
logger
=
logging
.
getLogger
(
__name__
)
VOCAB_FILES_NAMES
=
{
'vocab_file'
:
'spiece.model'
}
PRETRAINED_VOCAB_FILES_MAP
=
{
'vocab_file'
:
{
'albert-base-v1'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-spiece.model"
,
'albert-large-v1'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-spiece.model"
,
'albert-xlarge-v1'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-spiece.model"
,
'albert-xxlarge-v1'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-spiece.model"
,
'albert-base-v2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-spiece.model"
,
'albert-large-v2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-spiece.model"
,
'albert-xlarge-v2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-spiece.model"
,
'albert-xxlarge-v2'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-spiece.model"
,
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{
'albert-base-v1'
:
512
,
'albert-large-v1'
:
512
,
'albert-xlarge-v1'
:
512
,
'albert-xxlarge-v1'
:
512
,
'albert-base-v2'
:
512
,
'albert-large-v2'
:
512
,
'albert-xlarge-v2'
:
512
,
'albert-xxlarge-v2'
:
512
,
}
SPIECE_UNDERLINE
=
u
'▁'
class
AlbertTokenizer
(
PreTrainedTokenizer
):
"""
SentencePiece based tokenizer. Peculiarities:
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names
=
VOCAB_FILES_NAMES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
,
remove_space
=
True
,
keep_accents
=
False
,
bos_token
=
"[CLS]"
,
eos_token
=
"[SEP]"
,
unk_token
=
"<unk>"
,
sep_token
=
"[SEP]"
,
pad_token
=
"<pad>"
,
cls_token
=
"[CLS]"
,
mask_token
=
"[MASK]"
,
**
kwargs
):
super
(
AlbertTokenizer
,
self
).
__init__
(
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
sep_token
=
sep_token
,
pad_token
=
pad_token
,
cls_token
=
cls_token
,
mask_token
=
mask_token
,
**
kwargs
)
self
.
max_len_single_sentence
=
self
.
max_len
-
2
# take into account special tokens
self
.
max_len_sentences_pair
=
self
.
max_len
-
3
# take into account special tokens
try
:
import
sentencepiece
as
spm
except
ImportError
:
logger
.
warning
(
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self
.
do_lower_case
=
do_lower_case
self
.
remove_space
=
remove_space
self
.
keep_accents
=
keep_accents
self
.
vocab_file
=
vocab_file
self
.
sp_model
=
spm
.
SentencePieceProcessor
()
self
.
sp_model
.
Load
(
vocab_file
)
@
property
def
vocab_size
(
self
):
return
len
(
self
.
sp_model
)
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
[
"sp_model"
]
=
None
return
state
def
__setstate__
(
self
,
d
):
self
.
__dict__
=
d
try
:
import
sentencepiece
as
spm
except
ImportError
:
logger
.
warning
(
"You need to install SentencePiece to use AlbertTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self
.
sp_model
=
spm
.
SentencePieceProcessor
()
self
.
sp_model
.
Load
(
self
.
vocab_file
)
def
preprocess_text
(
self
,
inputs
):
if
self
.
remove_space
:
outputs
=
' '
.
join
(
inputs
.
strip
().
split
())
else
:
outputs
=
inputs
outputs
=
outputs
.
replace
(
"``"
,
'"'
).
replace
(
"''"
,
'"'
)
if
six
.
PY2
and
isinstance
(
outputs
,
str
):
outputs
=
outputs
.
decode
(
'utf-8'
)
if
not
self
.
keep_accents
:
outputs
=
unicodedata
.
normalize
(
'NFKD'
,
outputs
)
outputs
=
''
.
join
([
c
for
c
in
outputs
if
not
unicodedata
.
combining
(
c
)])
if
self
.
do_lower_case
:
outputs
=
outputs
.
lower
()
return
outputs
def
_tokenize
(
self
,
text
,
return_unicode
=
True
,
sample
=
False
):
""" Tokenize a string.
return_unicode is used only for py2
"""
text
=
self
.
preprocess_text
(
text
)
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
if
six
.
PY2
and
isinstance
(
text
,
unicode
):
text
=
text
.
encode
(
'utf-8'
)
if
not
sample
:
pieces
=
self
.
sp_model
.
EncodeAsPieces
(
text
)
else
:
pieces
=
self
.
sp_model
.
SampleEncodeAsPieces
(
text
,
64
,
0.1
)
new_pieces
=
[]
for
piece
in
pieces
:
if
len
(
piece
)
>
1
and
piece
[
-
1
]
==
str
(
','
)
and
piece
[
-
2
].
isdigit
():
cur_pieces
=
self
.
sp_model
.
EncodeAsPieces
(
piece
[:
-
1
].
replace
(
SPIECE_UNDERLINE
,
''
))
if
piece
[
0
]
!=
SPIECE_UNDERLINE
and
cur_pieces
[
0
][
0
]
==
SPIECE_UNDERLINE
:
if
len
(
cur_pieces
[
0
])
==
1
:
cur_pieces
=
cur_pieces
[
1
:]
else
:
cur_pieces
[
0
]
=
cur_pieces
[
0
][
1
:]
cur_pieces
.
append
(
piece
[
-
1
])
new_pieces
.
extend
(
cur_pieces
)
else
:
new_pieces
.
append
(
piece
)
# note(zhiliny): convert back to unicode for py2
if
six
.
PY2
and
return_unicode
:
ret_pieces
=
[]
for
piece
in
new_pieces
:
if
isinstance
(
piece
,
str
):
piece
=
piece
.
decode
(
'utf-8'
)
ret_pieces
.
append
(
piece
)
new_pieces
=
ret_pieces
return
new_pieces
def
_convert_token_to_id
(
self
,
token
):
""" Converts a token (str/unicode) in an id using the vocab. """
return
self
.
sp_model
.
PieceToId
(
token
)
def
_convert_id_to_token
(
self
,
index
,
return_unicode
=
True
):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
token
=
self
.
sp_model
.
IdToPiece
(
index
)
if
six
.
PY2
and
return_unicode
and
isinstance
(
token
,
str
):
token
=
token
.
decode
(
'utf-8'
)
return
token
def
convert_tokens_to_string
(
self
,
tokens
):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string
=
''
.
join
(
tokens
).
replace
(
SPIECE_UNDERLINE
,
' '
).
strip
()
return
out_string
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
,
token_ids_1
=
None
):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
An ALBERT sequence has the following format:
single sequence: [CLS] X [SEP]
pair of sequences: [CLS] A [SEP] B [SEP]
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
cls
+
token_ids_0
+
sep
return
cls
+
token_ids_0
+
sep
+
token_ids_1
+
sep
def
get_special_tokens_mask
(
self
,
token_ids_0
,
token_ids_1
=
None
,
already_has_special_tokens
=
False
):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
"""
if
already_has_special_tokens
:
if
token_ids_1
is
not
None
:
raise
ValueError
(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return
list
(
map
(
lambda
x
:
1
if
x
in
[
self
.
sep_token_id
,
self
.
cls_token_id
]
else
0
,
token_ids_0
))
if
token_ids_1
is
not
None
:
return
[
1
]
+
([
0
]
*
len
(
token_ids_0
))
+
[
1
]
+
([
0
]
*
len
(
token_ids_1
))
+
[
1
]
return
[
1
]
+
([
0
]
*
len
(
token_ids_0
))
+
[
1
]
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
,
token_ids_1
=
None
):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
An ALBERT sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
+
len
(
token_ids_1
+
sep
)
*
[
1
]
def
save_vocabulary
(
self
,
save_directory
):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
save_directory
))
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
VOCAB_FILES_NAMES
[
'vocab_file'
])
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
transformers/tokenization_auto.py
View file @
0558c9cb
...
@@ -27,6 +27,8 @@ from .tokenization_xlnet import XLNetTokenizer
...
@@ -27,6 +27,8 @@ from .tokenization_xlnet import XLNetTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_xlm
import
XLMTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_roberta
import
RobertaTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_distilbert
import
DistilBertTokenizer
from
.tokenization_camembert
import
CamembertTokenizer
from
.tokenization_albert
import
AlbertTokenizer
from
.tokenization_t5
import
T5Tokenizer
from
.tokenization_t5
import
T5Tokenizer
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -44,14 +46,16 @@ class AutoTokenizer(object):
...
@@ -44,14 +46,16 @@ class AutoTokenizer(object):
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
This class cannot be instantiated using `__init__()` (throw an error).
"""
"""
...
@@ -68,14 +72,16 @@ class AutoTokenizer(object):
...
@@ -68,14 +72,16 @@ class AutoTokenizer(object):
in the `pretrained_model_name_or_path` string (in the following order):
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: T5Tokenizer (T5 model)
- contains `t5`: T5Tokenizer (T5 model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
- contains `roberta`: RobertaTokenizer (XLM model)
- contains `albert`: AlbertTokenizer (ALBERT model)
- contains `camembert`: CamembertTokenizer (CamemBERT model)
- contains `roberta`: RobertaTokenizer (RoBERTa model)
- contains `bert`: BertTokenizer (Bert model)
- contains `bert`: BertTokenizer (Bert model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlnet`: XLNetTokenizer (XLNet model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `xlm`: XLMTokenizer (XLM model)
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
Params:
Params:
pretrained_model_name_or_path: either:
pretrained_model_name_or_path: either:
...
@@ -90,6 +96,9 @@ class AutoTokenizer(object):
...
@@ -90,6 +96,9 @@ class AutoTokenizer(object):
force_download: (`optional`) boolean, default False:
force_download: (`optional`) boolean, default False:
Force to (re-)download the vocabulary files and override the cached versions if they exists.
Force to (re-)download the vocabulary files and override the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
The proxies are used on each request.
...
@@ -108,6 +117,10 @@ class AutoTokenizer(object):
...
@@ -108,6 +117,10 @@ class AutoTokenizer(object):
return
T5Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
T5Tokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'distilbert'
in
pretrained_model_name_or_path
:
elif
'distilbert'
in
pretrained_model_name_or_path
:
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
DistilBertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'albert'
in
pretrained_model_name_or_path
:
return
AlbertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'camembert'
in
pretrained_model_name_or_path
:
return
CamembertTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'roberta'
in
pretrained_model_name_or_path
:
elif
'roberta'
in
pretrained_model_name_or_path
:
return
RobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
RobertaTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
elif
'bert'
in
pretrained_model_name_or_path
:
elif
'bert'
in
pretrained_model_name_or_path
:
...
@@ -126,4 +139,4 @@ class AutoTokenizer(object):
...
@@ -126,4 +139,4 @@ class AutoTokenizer(object):
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
return
CTRLTokenizer
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
raise
ValueError
(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta', '
ctrl
'"
.
format
(
pretrained_model_name_or_path
))
"'xlm', 'roberta', '
distilbert,' 'camembert', 'ctrl', 'albert
'"
.
format
(
pretrained_model_name_or_path
))
transformers/tokenization_bert.py
View file @
0558c9cb
...
@@ -220,7 +220,7 @@ class BertTokenizer(PreTrainedTokenizer):
...
@@ -220,7 +220,7 @@ class BertTokenizer(PreTrainedTokenizer):
special tokens for the model
special tokens for the model
Returns:
Returns:
A list of integers in the range [0, 1]:
0
for a special token,
1
for a sequence token.
A list of integers in the range [0, 1]:
1
for a special token,
0
for a sequence token.
"""
"""
if
already_has_special_tokens
:
if
already_has_special_tokens
:
...
...
transformers/tokenization_camembert.py
0 → 100644
View file @
0558c9cb
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
""" Tokenization classes for Camembert model."""
from
__future__
import
(
absolute_import
,
division
,
print_function
,
unicode_literals
)
import
logging
import
os
from
shutil
import
copyfile
import
sentencepiece
as
spm
from
transformers.tokenization_utils
import
PreTrainedTokenizer
logger
=
logging
.
getLogger
(
__name__
)
VOCAB_FILES_NAMES
=
{
'vocab_file'
:
'sentencepiece.bpe.model'
}
PRETRAINED_VOCAB_FILES_MAP
=
{
'vocab_file'
:
{
'camembert-base'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model"
,
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{
'camembert-base'
:
None
,
}
class
CamembertTokenizer
(
PreTrainedTokenizer
):
"""
Adapted from RobertaTokenizer and XLNetTokenizer
SentencePiece based tokenizer. Peculiarities:
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names
=
VOCAB_FILES_NAMES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def
__init__
(
self
,
vocab_file
,
bos_token
=
"<s>"
,
eos_token
=
"</s>"
,
sep_token
=
"</s>"
,
cls_token
=
"<s>"
,
unk_token
=
"<unk>"
,
pad_token
=
'<pad>'
,
mask_token
=
'<mask>'
,
additional_special_tokens
=
[
'<s>NOTUSED'
,
'</s>NOTUSED'
],
**
kwargs
):
super
(
CamembertTokenizer
,
self
).
__init__
(
max_len
=
512
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
sep_token
=
sep_token
,
cls_token
=
cls_token
,
pad_token
=
pad_token
,
mask_token
=
mask_token
,
additional_special_tokens
=
additional_special_tokens
,
**
kwargs
)
self
.
max_len_single_sentence
=
self
.
max_len
-
2
# take into account special tokens
self
.
max_len_sentences_pair
=
self
.
max_len
-
4
# take into account special tokens
self
.
sp_model
=
spm
.
SentencePieceProcessor
()
self
.
sp_model
.
Load
(
str
(
vocab_file
))
self
.
vocab_file
=
vocab_file
# HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual
# sentencepiece vocabulary (this is the case for <s> and </s>
self
.
fairseq_tokens_to_ids
=
{
'<s>NOTUSED'
:
0
,
'<pad>'
:
1
,
'</s>NOTUSED'
:
2
,
'<unk>'
:
3
}
self
.
fairseq_offset
=
len
(
self
.
fairseq_tokens_to_ids
)
self
.
fairseq_tokens_to_ids
[
'<mask>'
]
=
len
(
self
.
sp_model
)
+
len
(
self
.
fairseq_tokens_to_ids
)
self
.
fairseq_ids_to_tokens
=
{
v
:
k
for
k
,
v
in
self
.
fairseq_tokens_to_ids
.
items
()}
def
build_inputs_with_special_tokens
(
self
,
token_ids_0
,
token_ids_1
=
None
):
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A RoBERTa sequence has the following format:
single sequence: <s> X </s>
pair of sequences: <s> A </s></s> B </s>
"""
if
token_ids_1
is
None
:
return
[
self
.
cls_token_id
]
+
token_ids_0
+
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
sep
=
[
self
.
sep_token_id
]
return
cls
+
token_ids_0
+
sep
+
sep
+
token_ids_1
+
sep
def
get_special_tokens_mask
(
self
,
token_ids_0
,
token_ids_1
=
None
,
already_has_special_tokens
=
False
):
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
Args:
token_ids_0: list of ids (must not contain special tokens)
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
for sequence pairs
already_has_special_tokens: (default False) Set to True if the token list is already formated with
special tokens for the model
Returns:
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if
already_has_special_tokens
:
if
token_ids_1
is
not
None
:
raise
ValueError
(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return
list
(
map
(
lambda
x
:
1
if
x
in
[
self
.
sep_token_id
,
self
.
cls_token_id
]
else
0
,
token_ids_0
))
if
token_ids_1
is
None
:
return
[
1
]
+
([
0
]
*
len
(
token_ids_0
))
+
[
1
]
return
[
1
]
+
([
0
]
*
len
(
token_ids_0
))
+
[
1
,
1
]
+
([
0
]
*
len
(
token_ids_1
))
+
[
1
]
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
,
token_ids_1
=
None
):
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
A RoBERTa sequence pair mask has the following format:
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
| first sequence | second sequence
if token_ids_1 is None, only returns the first portion of the mask (0's).
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
+
sep
)
*
[
0
]
+
len
(
token_ids_1
+
sep
)
*
[
1
]
@
property
def
vocab_size
(
self
):
return
len
(
self
.
fairseq_tokens_to_ids
)
+
len
(
self
.
sp_model
)
def
_tokenize
(
self
,
text
):
return
self
.
sp_model
.
EncodeAsPieces
(
text
)
def
_convert_token_to_id
(
self
,
token
):
""" Converts a token (str/unicode) in an id using the vocab. """
if
token
in
self
.
fairseq_tokens_to_ids
:
return
self
.
fairseq_tokens_to_ids
[
token
]
elif
self
.
sp_model
.
PieceToId
(
token
)
==
0
:
# Convert sentence piece unk token to fairseq unk token index
return
self
.
unk_token_id
return
self
.
fairseq_offset
+
self
.
sp_model
.
PieceToId
(
token
)
def
_convert_id_to_token
(
self
,
index
):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
if
index
in
self
.
fairseq_ids_to_tokens
:
return
self
.
fairseq_ids_to_tokens
[
index
]
return
self
.
sp_model
.
IdToPiece
(
index
-
self
.
fairseq_offset
)
def
save_vocabulary
(
self
,
save_directory
):
""" Save the sentencepiece vocabulary (copy original file) and special tokens file
to a directory.
"""
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
save_directory
))
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
VOCAB_FILES_NAMES
[
'vocab_file'
])
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
transformers/tokenization_ctrl.py
View file @
0558c9cb
...
@@ -133,9 +133,11 @@ class CTRLTokenizer(PreTrainedTokenizer):
...
@@ -133,9 +133,11 @@ class CTRLTokenizer(PreTrainedTokenizer):
self
.
max_len_single_sentence
=
self
.
max_len
# no default special tokens - you can update this value if you add special tokens
self
.
max_len_single_sentence
=
self
.
max_len
# no default special tokens - you can update this value if you add special tokens
self
.
max_len_sentences_pair
=
self
.
max_len
# no default special tokens - you can update this value if you add special tokens
self
.
max_len_sentences_pair
=
self
.
max_len
# no default special tokens - you can update this value if you add special tokens
self
.
encoder
=
json
.
load
(
open
(
vocab_file
,
encoding
=
"utf-8"
))
with
open
(
vocab_file
,
encoding
=
"utf-8"
)
as
vocab_handle
:
self
.
encoder
=
json
.
load
(
vocab_handle
)
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
merges
=
open
(
merges_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[
1
:
-
1
]
with
open
(
merges_file
,
encoding
=
'utf-8'
)
as
merges_handle
:
merges
=
merges_handle
.
read
().
split
(
'
\n
'
)[
1
:
-
1
]
merges
=
[
tuple
(
merge
.
split
())
for
merge
in
merges
]
merges
=
[
tuple
(
merge
.
split
())
for
merge
in
merges
]
self
.
bpe_ranks
=
dict
(
zip
(
merges
,
range
(
len
(
merges
))))
self
.
bpe_ranks
=
dict
(
zip
(
merges
,
range
(
len
(
merges
))))
self
.
cache
=
{}
self
.
cache
=
{}
...
@@ -192,9 +194,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
...
@@ -192,9 +194,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
"""
"""
split_tokens
=
[]
split_tokens
=
[]
text
=
text
.
split
(
' '
)
words
=
re
.
findall
(
r
'\S+\n?'
,
text
)
for
token
in
text
:
for
token
in
words
:
split_tokens
.
extend
([
t
for
t
in
self
.
bpe
(
token
).
split
(
' '
)])
split_tokens
.
extend
([
t
for
t
in
self
.
bpe
(
token
).
split
(
' '
)])
return
split_tokens
return
split_tokens
...
...
transformers/tokenization_distilbert.py
View file @
0558c9cb
...
@@ -33,12 +33,16 @@ PRETRAINED_VOCAB_FILES_MAP = {
...
@@ -33,12 +33,16 @@ PRETRAINED_VOCAB_FILES_MAP = {
{
{
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
,
'distilbert-base-uncased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt"
,
'distilbert-base-uncased-distilled-squad'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt"
,
'distilbert-base-german-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-german-cased-vocab.txt"
,
'distilbert-base-multilingual-cased'
:
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt"
,
}
}
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
=
{
'distilbert-base-uncased'
:
512
,
'distilbert-base-uncased'
:
512
,
'distilbert-base-uncased-distilled-squad'
:
512
,
'distilbert-base-uncased-distilled-squad'
:
512
,
'distilbert-base-german-cased'
:
512
,
'distilbert-base-multilingual-cased'
:
512
,
}
}
...
...
Prev
1
…
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment