chenpangpang / transformers · Commits

Commit f31154cb, authored Jul 16, 2019 by thomwolf
Merge branch 'xlnet'
Parents: 78462aad, 1b35d05d

The commit touches 125 files; this page shows 20 changed files with 2739 additions and 334 deletions (+2739 / −334).
Changed files:

pytorch_transformers/tests/modeling_gpt2_test.py            +48    −0
pytorch_transformers/tests/modeling_openai_test.py          +48    −0
pytorch_transformers/tests/modeling_transfo_xl_test.py      +67    −90
pytorch_transformers/tests/modeling_xlm_test.py             +293   −0
pytorch_transformers/tests/modeling_xlnet_test.py           +322   −0
pytorch_transformers/tests/optimization_test.py             +105   −0
pytorch_transformers/tests/tokenization_bert_test.py        +19    −36
pytorch_transformers/tests/tokenization_gpt2_test.py        +62    −0
pytorch_transformers/tests/tokenization_openai_test.py      +64    −0
pytorch_transformers/tests/tokenization_tests_commons.py    +136   −0
pytorch_transformers/tests/tokenization_transfo_xl_test.py  +66    −0
pytorch_transformers/tests/tokenization_utils_test.py       +46    −0
pytorch_transformers/tests/tokenization_xlm_test.py         +63    −0
pytorch_transformers/tests/tokenization_xlnet_test.py       +84    −0
pytorch_transformers/tokenization_bert.py                   +124   −120
pytorch_transformers/tokenization_gpt2.py                   +216   −0
pytorch_transformers/tokenization_openai.py                 +204   −0
pytorch_transformers/tokenization_transfo_xl.py             +68    −88
pytorch_transformers/tokenization_utils.py                  +473   −0
pytorch_transformers/tokenization_xlm.py                    +231   −0
pytorch_transformers/tests/modeling_gpt2_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import pytest

from pytorch_transformers import (GPT2Config, GPT2Model,
                                  GPT2LMHeadModel, GPT2DoubleHeadsModel)

from .modeling_common_test import CommonTestCases, ConfigTester


class GPT2ModelTest(unittest.TestCase):

    def test_config(self):
        config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = CommonTestCases.GPTModelTester(self,
                                                      config_class=GPT2Config,
                                                      base_model_class=GPT2Model,
                                                      lm_head_model_class=GPT2LMHeadModel,
                                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_common_tests(test_presents=True)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = CommonTestCases.GPTModelTester(self,
                                                      config_class=GPT2Config,
                                                      base_model_class=GPT2Model,
                                                      lm_head_model_class=GPT2LMHeadModel,
                                                      double_head_model_class=GPT2DoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
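The GPT-2 and OpenAI GPT suites in this commit delegate all of their checks to the shared CommonTestCases.GPTModelTester and ConfigTester helpers in modeling_common_test.py. As a minimal sketch (not part of the commit), the new modules can be collected with the standard unittest loader; this assumes you run from the repository root and that pytorch_transformers/tests is an importable package, since the tests use relative imports such as .modeling_common_test:

import unittest

# Minimal sketch, assuming the repository root is the working directory and
# pytorch_transformers/tests contains an __init__.py (the relative imports require it).
suite = unittest.defaultTestLoader.discover(
    start_dir="pytorch_transformers/tests",
    pattern="modeling_*_test.py",
    top_level_dir=".")
unittest.TextTestRunner(verbosity=2).run(suite)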
pytorch_transformers/tests/modeling_openai_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import pytest

from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
                                  OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)

from .modeling_common_test import CommonTestCases, ConfigTester


class OpenAIModelTest(unittest.TestCase):

    def test_config(self):
        config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
        config_tester.run_common_tests()

    def test_model(self):
        model_tester = CommonTestCases.GPTModelTester(self,
                                                      config_class=OpenAIGPTConfig,
                                                      base_model_class=OpenAIGPTModel,
                                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_common_tests(test_presents=False)

    @pytest.mark.slow
    def test_pretrained(self):
        model_tester = CommonTestCases.GPTModelTester(self,
                                                      config_class=OpenAIGPTConfig,
                                                      base_model_class=OpenAIGPTModel,
                                                      lm_head_model_class=OpenAIGPTLMHeadModel,
                                                      double_head_model_class=OpenAIGPTDoubleHeadsModel)
        model_tester.run_slow_tests()


if __name__ == "__main__":
    unittest.main()
tests/modeling_transfo_xl_test.py → pytorch_transformers/tests/modeling_transfo_xl_test.py (renamed and modified)

@@ -25,10 +25,18 @@ import pytest
 import torch

-from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
-from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
+from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
+from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
+
+from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor

-class TransfoXLModelTest(unittest.TestCase):
-    class TransfoXLModelTester(object):
+class TransfoXLModelTest(CommonTestCases.CommonModelTester):
+
+    all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel)
+    test_pruning = False
+    test_torchscript = False
+    test_resize_embeddings = False
+
+    class TransfoXLModelTester(object):

         def __init__(self,
...
@@ -41,54 +49,56 @@ class TransfoXLModelTest(unittest.TestCase):
                      use_labels=True,
                      vocab_size=99,
                      cutoffs=[10, 50, 80],
-                     d_model=32,
+                     hidden_size=32,
                      d_embed=32,
-                     n_head=4,
+                     num_attention_heads=4,
                      d_head=8,
                      d_inner=128,
                      div_val=2,
-                     n_layer=5,
+                     num_hidden_layers=5,
                      scope=None,
-                     seed=1):
+                     seed=1,
+                     ):
             self.parent = parent
             self.batch_size = batch_size
             self.seq_length = seq_length
             self.mem_len = mem_len
-            self.key_len = seq_length + mem_len
             self.clamp_len = clamp_len
             self.is_training = is_training
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.cutoffs = cutoffs
-            self.d_model = d_model
+            self.hidden_size = hidden_size
             self.d_embed = d_embed
-            self.n_head = n_head
+            self.num_attention_heads = num_attention_heads
             self.d_head = d_head
             self.d_inner = d_inner
             self.div_val = div_val
-            self.n_layer = n_layer
+            self.num_hidden_layers = num_hidden_layers
             self.scope = scope
             self.seed = seed

         def prepare_config_and_inputs(self):
-            input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-            input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             lm_labels = None
             if self.use_labels:
-                lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             config = TransfoXLConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
                 mem_len=self.mem_len,
                 clamp_len=self.clamp_len,
                 cutoffs=self.cutoffs,
-                d_model=self.d_model,
+                d_model=self.hidden_size,
                 d_embed=self.d_embed,
-                n_head=self.n_head,
+                n_head=self.num_attention_heads,
                 d_head=self.d_head,
                 d_inner=self.d_inner,
                 div_val=self.div_val,
-                n_layer=self.n_layer)
+                n_layer=self.num_hidden_layers)

             return (config, input_ids_1, input_ids_2, lm_labels)
...
@@ -113,37 +123,34 @@ class TransfoXLModelTest(unittest.TestCase):
         def check_transfo_xl_model_output(self, result):
             self.parent.assertListEqual(
                 list(result["hidden_states_1"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(result["hidden_states_2"].size()),
-                [self.batch_size, self.seq_length, self.d_model])
+                [self.batch_size, self.seq_length, self.hidden_size])
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_1"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
             self.parent.assertListEqual(
                 list(list(mem.size()) for mem in result["mems_2"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

         def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
             model = TransfoXLLMHeadModel(config)
             model.eval()

-            loss_1, mems_1a = model(input_ids_1, target=lm_labels)
-            lm_logits_1, mems_1b = model(input_ids_1)
+            lm_logits_1, mems_1 = model(input_ids_1)
+            loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)

-            loss_2, mems_2a = model(input_ids_2, target=lm_labels, mems=mems_1a)
-            lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
+            lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
+            loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)

             outputs = {
                 "loss_1": loss_1,
-                "mems_1a": mems_1a,
+                "mems_1": mems_1,
                 "lm_logits_1": lm_logits_1,
-                "mems_1b": mems_1b,
                 "loss_2": loss_2,
-                "mems_2a": mems_2a,
+                "mems_2": mems_2,
                 "lm_logits_2": lm_logits_2,
-                "mems_2b": mems_2b,
             }
             return outputs
...
@@ -155,14 +162,8 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_1"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_1b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
+                list(list(mem.size()) for mem in result["mems_1"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

             self.parent.assertListEqual(
                 list(result["loss_2"].size()),
...
@@ -171,66 +172,42 @@ class TransfoXLModelTest(unittest.TestCase):
                 list(result["lm_logits_2"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
             self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2a"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(list(mem.size()) for mem in result["mems_2b"]),
-                [[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
-            self.parent.assertListEqual(
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
-                list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
-
-    def test_default(self):
-        self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
-
-    def test_config_to_json_string(self):
-        config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        obj = json.loads(config.to_json_string())
-        self.assertEqual(obj["n_token"], 96)
-        self.assertEqual(obj["d_embed"], 37)
-
-    def test_config_to_json_file(self):
-        config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
-        json_file_path = "/tmp/config.json"
-        config_first.to_json_file(json_file_path)
-        config_second = TransfoXLConfig.from_json_file(json_file_path)
-        os.remove(json_file_path)
-        self.assertEqual(config_second.to_dict(), config_first.to_dict())
-
-    @pytest.mark.slow
-    def test_model_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
-            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(model)
-
-    def run_tester(self, tester):
-        config_and_inputs = tester.prepare_config_and_inputs()
-        tester.set_seed()
-        output_result = tester.create_transfo_xl_model(*config_and_inputs)
-        tester.check_transfo_xl_model_output(output_result)
-
-        tester.set_seed()
-        output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
-        tester.check_transfo_xl_lm_head_output(output_result)
-
-    @classmethod
-    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
-        """Creates a random int32 tensor of the shape within the vocab size."""
-        if rng is None:
-            rng = random.Random()
-
-        total_dims = 1
-        for dim in shape:
-            total_dims *= dim
-
-        values = []
-        for _ in range(total_dims):
-            values.append(rng.randint(0, vocab_size - 1))
-
-        return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
+                list(list(mem.size()) for mem in result["mems_2"]),
+                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
+
+        def prepare_config_and_inputs_for_common(self):
+            config_and_inputs = self.prepare_config_and_inputs()
+            (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
+            inputs_dict = {'input_ids': input_ids_1}
+            return config, inputs_dict
+
+    def setUp(self):
+        self.model_tester = TransfoXLModelTest.TransfoXLModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_transfo_xl_model(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_model(*config_and_inputs)
+        self.model_tester.check_transfo_xl_model_output(output_result)
+
+    def test_transfo_xl_lm_head(self):
+        self.model_tester.set_seed()
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
+        self.model_tester.check_transfo_xl_lm_head_output(output_result)
+
+    @pytest.mark.slow
+    def test_model_from_pretrained(self):
+        cache_dir = "/tmp/pytorch_transformers_test/"
+        for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
+            model = TransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
+            shutil.rmtree(cache_dir)
+            self.assertIsNotNone(model)

 if __name__ == "__main__":
...
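The most consequential change in this diff is the Transformer-XL LM-head call convention: the old target= keyword becomes labels=, and the labelled call now returns three values instead of two. Below is a small self-contained sketch of the new convention (not part of the commit), using a tiny configuration built with the same keyword arguments the tester passes; the sizes are illustrative, not pretrained values.

import torch
from pytorch_transformers import TransfoXLConfig, TransfoXLLMHeadModel

# Tiny illustrative config mirroring TransfoXLModelTester; sizes are arbitrary small values.
config = TransfoXLConfig(vocab_size_or_config_json_file=99, mem_len=30, clamp_len=1000,
                         cutoffs=[10, 50, 80], d_model=32, d_embed=32, n_head=4,
                         d_head=8, d_inner=128, div_val=2, n_layer=5)
model = TransfoXLLMHeadModel(config)
model.eval()

input_ids = torch.randint(0, 99, (2, 7), dtype=torch.long)
lm_labels = torch.randint(0, 99, (2, 7), dtype=torch.long)

lm_logits, mems = model(input_ids)                                # no labels: logits first, then memories
loss, _, mems = model(input_ids, labels=lm_labels)                # new keyword `labels=` replaces `target=`
_, _, next_mems = model(input_ids, labels=lm_labels, mems=mems)   # memories carried across segments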
pytorch_transformers/tests/modeling_xlm_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest

from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel,
                                  XLMForQuestionAnswering, XLMForSequenceClassification)
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP

from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)


class XLMModelTest(CommonTestCases.CommonModelTester):

    all_model_classes = (XLMModel, XLMWithLMHeadModel,
                         XLMForQuestionAnswering, XLMForSequenceClassification)
    # , XLMForSequenceClassification, XLMForTokenClassification),

    class XLMModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_lengths=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     gelu_activation=True,
                     sinusoidal_embeddings=False,
                     causal=False,
                     asm=False,
                     n_langs=2,
                     vocab_size=99,
                     n_special=0,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     summary_type="last",
                     use_proj=True,
                     scope=None,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_lengths = use_input_lengths
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.gelu_activation = gelu_activation
            self.sinusoidal_embeddings = sinusoidal_embeddings
            self.asm = asm
            self.n_langs = n_langs
            self.vocab_size = vocab_size
            self.n_special = n_special
            self.summary_type = summary_type
            self.causal = causal
            self.use_proj = use_proj
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.n_langs = n_langs
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.summary_type = summary_type
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_lengths = None
            if self.use_input_lengths:
                input_lengths = ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2  # small variation of seq_length

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.n_langs)

            sequence_labels = None
            token_labels = None
            is_impossible_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLMConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                n_special=self.n_special,
                emb_dim=self.hidden_size,
                n_layers=self.num_hidden_layers,
                n_heads=self.num_attention_heads,
                dropout=self.hidden_dropout_prob,
                attention_dropout=self.attention_probs_dropout_prob,
                gelu_activation=self.gelu_activation,
                sinusoidal_embeddings=self.sinusoidal_embeddings,
                asm=self.asm,
                causal=self.causal,
                n_langs=self.n_langs,
                max_position_embeddings=self.max_position_embeddings,
                initializer_range=self.initializer_range,
                summary_type=self.summary_type,
                use_proj=self.use_proj)

            return config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_and_check_xlm_model(self, config, input_ids, token_type_ids, input_lengths,
                                       sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMModel(config=config)
            model.eval()
            outputs = model(input_ids, lengths=input_lengths, langs=token_type_ids)
            outputs = model(input_ids, langs=token_type_ids)
            outputs = model(input_ids)
            sequence_output = outputs[0]
            result = {
                "sequence_output": sequence_output,
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])

        def create_and_check_xlm_lm_head(self, config, input_ids, token_type_ids, input_lengths,
                                         sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMWithLMHeadModel(config)
            model.eval()

            loss, logits = model(input_ids, token_type_ids=token_type_ids, labels=token_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])

        def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths,
                                    sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMForQuestionAnswering(config)
            model.eval()

            outputs = model(input_ids)
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
                            cls_index=sequence_labels, is_impossible=is_impossible_labels,
                            p_mask=input_mask)

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels,
                            cls_index=sequence_labels, is_impossible=is_impossible_labels)

            (total_loss,) = outputs

            outputs = model(input_ids, start_positions=sequence_labels, end_positions=sequence_labels)

            (total_loss,) = outputs

            result = {
                "loss": total_loss,
                "start_top_log_probs": start_top_log_probs,
                "start_top_index": start_top_index,
                "end_top_log_probs": end_top_log_probs,
                "end_top_index": end_top_index,
                "cls_logits": cls_logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["start_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top])
            self.parent.assertListEqual(
                list(result["start_top_index"].size()),
                [self.batch_size, model.config.start_n_top])
            self.parent.assertListEqual(
                list(result["end_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            self.parent.assertListEqual(
                list(result["end_top_index"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            self.parent.assertListEqual(
                list(result["cls_logits"].size()),
                [self.batch_size])

        def create_and_check_xlm_sequence_classif(self, config, input_ids, token_type_ids, input_lengths,
                                                  sequence_labels, token_labels, is_impossible_labels, input_mask):
            model = XLMForSequenceClassification(config)
            model.eval()

            (logits,) = model(input_ids)
            loss, logits = model(input_ids, labels=sequence_labels)

            result = {
                "loss": loss,
                "logits": logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.type_sequence_label_size])

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, token_type_ids, input_lengths,
             sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = XLMModelTest.XLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLMConfig, emb_dim=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlm_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlm_model(*config_and_inputs)

        # config_and_inputs = tester.prepare_config_and_inputs()
        # tester.create_and_check_xlm_for_masked_lm(*config_and_inputs)

        # config_and_inputs = tester.prepare_config_and_inputs()
        # tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)

        # config_and_inputs = tester.prepare_config_and_inputs()
        # tester.create_and_check_xlm_for_question_answering(*config_and_inputs)

        # config_and_inputs = tester.prepare_config_and_inputs()
        # tester.create_and_check_xlm_for_sequence_classification(*config_and_inputs)

        # config_and_inputs = tester.prepare_config_and_inputs()
        # tester.create_and_check_xlm_for_token_classification(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLMModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
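The XLM tester builds its model from the same small keyword arguments every time; the sketch below (not part of the commit) reproduces that construction outside the test harness and prints the shape asserted in create_and_check_xlm_model. Sizes follow the tester defaults and are illustrative only.

import torch
from pytorch_transformers import XLMConfig, XLMModel

# Same keyword arguments XLMModelTester passes to XLMConfig, with its small default sizes.
config = XLMConfig(vocab_size_or_config_json_file=99, n_special=0, emb_dim=32, n_layers=5,
                   n_heads=4, dropout=0.1, attention_dropout=0.1, gelu_activation=True,
                   sinusoidal_embeddings=False, asm=False, causal=False, n_langs=2,
                   max_position_embeddings=512, initializer_range=0.02,
                   summary_type="last", use_proj=True)
model = XLMModel(config)
model.eval()

input_ids = torch.randint(0, 99, (13, 7), dtype=torch.long)  # (batch_size, seq_length) as in the tester
langs = torch.randint(0, 2, (13, 7), dtype=torch.long)       # language ids, n_langs=2
sequence_output = model(input_ids, langs=langs)[0]
print(sequence_output.shape)  # torch.Size([13, 7, 32]), matching the assertion in the test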
pytorch_transformers/tests/modeling_xlnet_test.py (new file, mode 100644)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

import torch

from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel,
                                  XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP

from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor


class XLNetModelTest(CommonTestCases.CommonModelTester):

    all_model_classes = (XLNetModel, XLNetLMHeadModel,
                         XLNetForSequenceClassification, XLNetForQuestionAnswering)
    test_pruning = False

    class XLNetModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=10,
                     clamp_len=-1,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     hidden_size=32,
                     num_attention_heads=4,
                     d_inner=128,
                     num_hidden_layers=5,
                     max_position_embeddings=10,
                     type_sequence_label_size=2,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     initializer_range=0.05,
                     seed=1,
                     type_vocab_size=2,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            # self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.num_hidden_layers = num_hidden_layers
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2).float()

            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
            target_mapping[:, 0, -1] = 1.0  # predict last token

            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([self.batch_size], 2).float()

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size)

            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

        def set_seed(self):
            random.seed(self.seed)
            torch.manual_seed(self.seed)

        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                              target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetModel(config)
            model.eval()

            _, _ = model(input_ids_1, input_mask=input_mask)
            _, _ = model(input_ids_1, attention_mask=input_mask)
            _, _ = model(input_ids_1, token_type_ids=segment_ids)
            outputs, mems_1 = model(input_ids_1)

            result = {
                "mems_1": mems_1,
                "outputs": outputs,
            }

            self.parent.assertListEqual(
                list(result["outputs"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                           target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetLMHeadModel(config)
            model.eval()

            loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)

            loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)

            logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)

            result = {
                "loss_1": loss_1,
                "mems_1": mems_1,
                "all_logits_1": all_logits_1,
                "loss_2": loss_2,
                "mems_2": mems_2,
                "all_logits_2": all_logits_2,
            }

            self.parent.assertListEqual(
                list(result["loss_1"].size()),
                [])
            self.parent.assertListEqual(
                list(result["all_logits_1"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

            self.parent.assertListEqual(
                list(result["loss_2"].size()),
                [])
            self.parent.assertListEqual(
                list(result["all_logits_2"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                      target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetForQuestionAnswering(config)
            model.eval()

            outputs = model(input_ids_1)
            start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs

            outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,
                            cls_index=sequence_labels, is_impossible=is_impossible_labels,
                            p_mask=input_mask)

            outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,
                            cls_index=sequence_labels, is_impossible=is_impossible_labels)

            total_loss, mems = outputs

            outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)

            total_loss, mems = outputs

            result = {
                "loss": total_loss,
                "start_top_log_probs": start_top_log_probs,
                "start_top_index": start_top_index,
                "end_top_log_probs": end_top_log_probs,
                "end_top_index": end_top_index,
                "cls_logits": cls_logits,
                "mems": mems,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["start_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top])
            self.parent.assertListEqual(
                list(result["start_top_index"].size()),
                [self.batch_size, model.config.start_n_top])
            self.parent.assertListEqual(
                list(result["end_top_log_probs"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            self.parent.assertListEqual(
                list(result["end_top_index"].size()),
                [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            self.parent.assertListEqual(
                list(result["cls_logits"].size()),
                [self.batch_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = XLNetForSequenceClassification(config)
            model.eval()

            logits, mems_1 = model(input_ids_1)
            loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)

            result = {
                "loss": loss,
                "mems_1": mems_1,
                "logits": logits,
            }

            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.type_sequence_label_size])
            self.parent.assertListEqual(
                list(list(mem.size()) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
             target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids_1}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = XLNetModelTest.XLNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

    def test_xlnet_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)

    def test_xlnet_sequence_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

    def test_xlnet_qa(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = XLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
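create_and_check_xlnet_lm_head exercises XLNet's permutation-mask interface: perm_mask hides the final position from every token and target_mapping asks the model to score only that position. Here is a compact sketch of the same setup outside the tester (not part of the commit); the configuration values mirror the tester and are illustrative only.

import torch
from pytorch_transformers import XLNetConfig, XLNetLMHeadModel

# Small illustrative config mirroring XLNetModelTester's keyword arguments.
config = XLNetConfig(vocab_size_or_config_json_file=99, d_model=32, n_head=4, d_inner=128,
                     n_layer=5, untie_r=True, max_position_embeddings=10, mem_len=10,
                     clamp_len=-1, same_length=False, reuse_len=15, bi_data=False,
                     initializer_range=0.05)
model = XLNetLMHeadModel(config)
model.eval()

batch_size, seq_length = 2, 7
input_ids = torch.randint(0, 99, (batch_size, seq_length + 1), dtype=torch.long)

perm_mask = torch.zeros(batch_size, seq_length + 1, seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0        # previous tokens don't see the last token
target_mapping = torch.zeros(batch_size, 1, seq_length + 1, dtype=torch.float)
target_mapping[:, 0, -1] = 1.0   # predict only the last token

logits, mems = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
print(logits.shape)  # one scored position per example: torch.Size([2, 1, 99])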
tests/optimization_test.py → pytorch_transformers/tests/optimization_test.py (renamed and modified)

@@ -20,13 +20,19 @@ import unittest
 import torch

-from pytorch_pretrained_bert import BertAdam
-from pytorch_pretrained_bert import OpenAIAdam
-from pytorch_pretrained_bert.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
-    WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
+from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
+                                  WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule,
+                                  WarmupLinearSchedule)

 import numpy as np

+def unwrap_schedule(scheduler, num_steps=10):
+    lrs = []
+    for _ in range(num_steps):
+        scheduler.step()
+        lrs.append(scheduler.get_lr())
+    return lrs
+
 class OptimizationTest(unittest.TestCase):

     def assertListAlmostEqual(self, list1, list2, tol):
...
@@ -34,14 +40,12 @@ class OptimizationTest(unittest.TestCase):
         for a, b in zip(list1, list2):
             self.assertAlmostEqual(a, b, delta=tol)

-    def test_adam(self):
+    def test_adam_w(self):
         w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
         target = torch.tensor([0.4, 0.2, -0.5])
         criterion = torch.nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
-        optimizer = BertAdam(params=[w], lr=2e-1,
-                             weight_decay=0.0, max_grad_norm=-1)
+        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
         for _ in range(100):
             loss = criterion(w, target)
             loss.backward()
...
@@ -52,39 +56,49 @@ class OptimizationTest(unittest.TestCase):
 class ScheduleInitTest(unittest.TestCase):
-    def test_bert_sched_init(self):
-        m = torch.nn.Linear(50, 50)
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
-        # shouldn't fail
-
-    def test_openai_sched_init(self):
-        m = torch.nn.Linear(50, 50)
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
-        # shouldn't fail
-
-class WarmupCosineWithRestartsTest(unittest.TestCase):
-    def test_it(self):
-        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1000., cycles=5)
-        x = np.arange(0, 1000)
-        y = [m.get_lr(xe) for xe in x]
-        y = np.asarray(y)
-        expected_zeros = y[[0, 200, 400, 600, 800]]
-        print(expected_zeros)
-        expected_ones = y[[50, 250, 450, 650, 850]]
-        print(expected_ones)
-        self.assertTrue(np.allclose(expected_ones, 1))
-        self.assertTrue(np.allclose(expected_zeros, 0))
+    m = torch.nn.Linear(50, 50)
+    optimizer = AdamW(m.parameters(), lr=10.)
+    num_steps = 10
+
+    def assertListAlmostEqual(self, list1, list2, tol):
+        self.assertEqual(len(list1), len(list2))
+        for a, b in zip(list1, list2):
+            self.assertAlmostEqual(a, b, delta=tol)
+
+    def test_constant_scheduler(self):
+        scheduler = ConstantLRSchedule(self.optimizer)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [10.] * self.num_steps
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_constant_scheduler(self):
+        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_linear_scheduler(self):
+        scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)
+
+    def test_warmup_cosine_scheduler(self):
+        scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)
+
+    def test_warmup_cosine_hard_restart_scheduler(self):
+        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
+        lrs = unwrap_schedule(scheduler, self.num_steps)
+        expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
+        self.assertEqual(len(lrs[0]), 1)
+        self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

 if __name__ == "__main__":
...
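The optimization tests now separate the optimizer (AdamW) from the learning-rate schedule objects, which are stepped explicitly; unwrap_schedule above simply records get_lr() after each step. A minimal usage sketch of that pattern, reusing the same values as ScheduleInitTest (illustrative only, and omitting the optimizer.step() a real training loop would also call):

import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(50, 50)
optimizer = AdamW(model.parameters(), lr=10.)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=2, t_total=10)

learning_rates = []
for _ in range(10):
    scheduler.step()                              # advance the schedule once per training step
    learning_rates.append(scheduler.get_lr()[0])  # get_lr() returns one value per parameter group
print(learning_rates)  # ramps up for 2 warmup steps, then decays linearly toward 0, as asserted above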
tests/tokenization_test.py → pytorch_transformers/tests/tokenization_bert_test.py (renamed and modified)

@@ -17,54 +17,37 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 from io import open
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
-                                                  BertTokenizer,
-                                                  WordpieceTokenizer,
-                                                  _is_control, _is_punctuation,
-                                                  _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
+from pytorch_transformers.tokenization_bert import (BasicTokenizer,
+                                                    BertTokenizer,
+                                                    WordpieceTokenizer,
+                                                    _is_control, _is_punctuation,
+                                                    _is_whitespace, VOCAB_FILES_NAMES)
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class TokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
-            "##ing", ","
+            "##ing", ",", "low", "lowest",
         ]
-        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
-            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
-            vocab_file = vocab_writer.name
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-        tokenizer = BertTokenizer(vocab_file)
-        os.remove(vocab_file)
+            input_text = u"UNwant\u00E9d,running"
+            output_text = u"unwanted, running"

-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
+            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)

-        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer.from_pretrained(vocab_file)
-        os.remove(vocab_file)
+            tokenizer = BertTokenizer(vocab_file)

-        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
-
-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+            tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
+            self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+            self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

     def test_chinese(self):
         tokenizer = BasicTokenizer()
...
@@ -97,7 +80,7 @@ class TokenizationTest(unittest.TestCase):
         vocab = {}
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i
-        tokenizer = WordpieceTokenizer(vocab=vocab)
+        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

         self.assertListEqual(tokenizer.tokenize(""), [])
...
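The rewritten tokenizer tests write vocabulary files under the canonical names in VOCAB_FILES_NAMES so that the whole directory can be reloaded. Below is a sketch of that round trip outside the test helpers (not part of the commit); it uses Python 3's tempfile.TemporaryDirectory in place of the tests' own TemporaryDirectory shim, and assumes BertTokenizer.from_pretrained accepts a local directory, which is what create_and_check_tokenizer_commons relies on.

import os
import tempfile
from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
                "##ing", ",", "low", "lowest"]

with tempfile.TemporaryDirectory() as tmpdirname:
    # Write the vocab under its canonical file name so the directory itself is loadable.
    vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
    with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
        vocab_writer.write("".join(token + "\n" for token in vocab_tokens))

    tokenizer = BertTokenizer.from_pretrained(tmpdirname)
    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    print(tokens)  # ["un", "##want", "##ed", ",", "runn", "##ing"], as asserted in the test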
tests/tokenization_gpt2_test.py → pytorch_transformers/tests/tokenization_gpt2_test.py (renamed and modified)

@@ -17,11 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class GPT2TokenizationTest(unittest.TestCase):
...
@@ -29,49 +28,35 @@ class GPT2TokenizationTest(unittest.TestCase):
         """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "lo", "low", "er",
-                 "low", "lowest", "newer", "wider"]
+                 "low", "lowest", "newer", "wider", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [13, 12, 16]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
-        os.remove(vocab_file)
-        os.remove(merges_file)
-        os.remove(special_tokens_file)
-
-        self.assertListEqual(
-            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
-             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
-            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
-             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
-
-    # @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+        special_tokens_map = {"unk_token": "<unk>"}
+
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            input_text = u"lower newer"
+            output_text = u"lower<unk>newer"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+
+            tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
+            text = "lower"
+            bpe_tokens = ["low", "er"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + [tokenizer.unk_token]
+            input_bpe_tokens = [13, 12, 17]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

 if __name__ == '__main__':
     unittest.main()
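For GPT-2, special-token handling moves from a positional special_tokens list to named keyword arguments (here unk_token), matching the special_tokens_map used by the test. A small sketch with the test's toy vocabulary and merges (not part of the commit; tempfile stands in for the tests' TemporaryDirectory shim):

import json
import os
import tempfile
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES

vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
         "lo", "low", "er",
         "low", "lowest", "newer", "wider", "<unk>"]
vocab_tokens = dict(zip(vocab, range(len(vocab))))
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]

with tempfile.TemporaryDirectory() as tmpdirname:
    vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
    merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
    with open(vocab_file, "w") as fp:
        fp.write(json.dumps(vocab_tokens))
    with open(merges_file, "w") as fp:
        fp.write("\n".join(merges))

    # Named special-token arguments replace the old special_tokens=[...] list.
    tokenizer = GPT2Tokenizer(vocab_file, merges_file, unk_token="<unk>")
    print(tokenizer.tokenize("lower"))  # ["low", "er"], as in the test above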
tests/tokenization_openai_test.py → pytorch_transformers/tests/tokenization_openai_test.py (renamed and modified)

@@ -17,10 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import os
 import unittest
 import json
-import shutil
-import pytest

-from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, VOCAB_FILES_NAMES
+
+from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

 class OpenAIGPTTokenizationTest(unittest.TestCase):
...
@@ -30,49 +30,34 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                  "w</w>", "r</w>", "t</w>",
                  "lo", "low", "er</w>",
-                 "low</w>", "lowest</w>", "newer</w>", "wider</w>"]
+                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
         vocab_tokens = dict(zip(vocab, range(len(vocab))))
         merges = ["#version: 0.2", "l o", "lo w", "e r</w>", ""]
-        with open("/tmp/openai_tokenizer_vocab_test.json", "w") as fp:
-            fp.write(json.dumps(vocab_tokens))
-            vocab_file = fp.name
-        with open("/tmp/openai_tokenizer_merges_test.txt", "w") as fp:
-            fp.write("\n".join(merges))
-            merges_file = fp.name
-
-        tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
-        os.remove(vocab_file)
-        os.remove(merges_file)
-
-        text = "lower"
-        bpe_tokens = ["low", "er</w>"]
-        tokens = tokenizer.tokenize(text)
-        self.assertListEqual(tokens, bpe_tokens)
-
-        input_tokens = tokens + ["<unk>"]
-        input_bpe_tokens = [14, 15, 20]
-        self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
-
-        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
-        tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
-        os.remove(vocab_file)
-        os.remove(merges_file)
-        os.remove(special_tokens_file)
-
-        self.assertListEqual(
-            [tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
-             tokenizer.special_tokens, tokenizer.special_tokens_decoder],
-            [tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
-             tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
-
-    @pytest.mark.slow
-    def test_tokenizer_from_pretrained(self):
-        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
-        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
-            tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
-            shutil.rmtree(cache_dir)
-            self.assertIsNotNone(tokenizer)
+        with TemporaryDirectory() as tmpdirname:
+            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
+            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
+            with open(vocab_file, "w") as fp:
+                fp.write(json.dumps(vocab_tokens))
+            with open(merges_file, "w") as fp:
+                fp.write("\n".join(merges))
+
+            input_text = u"lower newer"
+            output_text = u"lower newer"
+
+            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
+
+            tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
+
+            text = "lower"
+            bpe_tokens = ["low", "er</w>"]
+            tokens = tokenizer.tokenize(text)
+            self.assertListEqual(tokens, bpe_tokens)
+
+            input_tokens = tokens + ["<unk>"]
+            input_bpe_tokens = [14, 15, 20]
+            self.assertListEqual(
+                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

 if __name__ == '__main__':
...
pytorch_transformers/tests/tokenization_tests_commons.py (new file)
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
from io import open
import tempfile
import shutil

if sys.version_info[0] == 2:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name
        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
else:
    import pickle
    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str


def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

    with TemporaryDirectory() as tmpdirname:
        tokenizer.save_pretrained(tmpdirname)
        tokenizer = tokenizer.from_pretrained(tmpdirname)

    after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
    tester.assertListEqual(before_tokens, after_tokens)


def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
    tester.assertIsNotNone(tokenizer)

    text = u"Munich and Berlin are nice cities"
    subwords = tokenizer.tokenize(text)

    with TemporaryDirectory() as tmpdirname:
        filename = os.path.join(tmpdirname, u"tokenizer.bin")
        pickle.dump(tokenizer, open(filename, "wb"))
        tokenizer_new = pickle.load(open(filename, "rb"))

    subwords_loaded = tokenizer_new.tokenize(text)
    tester.assertListEqual(subwords, subwords_loaded)


def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    vocab_size = tokenizer.vocab_size
    all_size = len(tokenizer)

    tester.assertNotEqual(vocab_size, 0)
    tester.assertEqual(vocab_size, all_size)

    new_toks = ["aaaaabbbbbb", "cccccccccdddddddd"]
    added_toks = tokenizer.add_tokens(new_toks)
    vocab_size_2 = tokenizer.vocab_size
    all_size_2 = len(tokenizer)

    tester.assertNotEqual(vocab_size_2, 0)
    tester.assertEqual(vocab_size, vocab_size_2)
    tester.assertEqual(added_toks, len(new_toks))
    tester.assertEqual(all_size_2, all_size + len(new_toks))

    tokens = tokenizer.encode("aaaaabbbbbb low cccccccccdddddddd l")
    tester.assertGreaterEqual(len(tokens), 4)
    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

    new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
                  'pad_token': "<<<<<|||>|>>>>|>"}
    added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
    vocab_size_3 = tokenizer.vocab_size
    all_size_3 = len(tokenizer)

    tester.assertNotEqual(vocab_size_3, 0)
    tester.assertEqual(vocab_size, vocab_size_3)
    tester.assertEqual(added_toks_2, len(new_toks_2))
    tester.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

    tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")

    tester.assertGreaterEqual(len(tokens), 6)
    tester.assertGreater(tokens[0], tokenizer.vocab_size - 1)
    tester.assertGreater(tokens[0], tokens[1])
    tester.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
    tester.assertGreater(tokens[-2], tokens[-3])
    tester.assertEqual(tokens[0], tokenizer.convert_tokens_to_ids(tokenizer.eos_token))
    tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))


def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
    tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

    tokens = tokenizer.tokenize(input_text)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_2 = tokenizer.encode(input_text)
    tester.assertListEqual(ids, ids_2)

    tokens_2 = tokenizer.convert_ids_to_tokens(ids)
    text_2 = tokenizer.decode(ids)

    tester.assertEqual(text_2, output_text)

    tester.assertNotEqual(len(tokens_2), 0)
    tester.assertIsInstance(text_2, (str, unicode))


def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
    create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
    create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
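A minimal sketch of how a model-specific test is expected to drive these shared checks. The BERT vocabulary, class name, and expected strings below are illustrative assumptions for this sketch, not part of the commit:

# Hypothetical usage sketch (not from the diff): exercising the common checks
# above with BertTokenizer and a tiny throw-away WordPiece vocabulary.
import os
import unittest
from io import open

from pytorch_transformers.tokenization_bert import BertTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory


class ExampleTokenizationTest(unittest.TestCase):
    def test_commons(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                        "wa", "un", "runn", "##ing", ",", "low", "lowest"]
        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
            with open(vocab_file, "w", encoding="utf-8") as writer:
                writer.write("".join(tok + "\n" for tok in vocab_tokens))
            # Runs the encode/decode, add_tokens, save/load and pickle round-trips defined above.
            create_and_check_tokenizer_commons(self, u"UNwant\u00E9d,running", u"unwanted, running",
                                               BertTokenizer, tmpdirname, do_lower_case=True)

Each helper re-instantiates the tokenizer through from_pretrained, so the fixture directory only needs to contain the files named in the tokenizer's vocab_files_names.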
tests/tokenization_transfo_xl_test.py → pytorch_transformers/tests/tokenization_transfo_xl_test.py
...
@@ -17,42 +17,35 @@ from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
from io import open

from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, VOCAB_FILES_NAMES

from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory


class TransfoXLTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        vocab_tokens = [
            "<unk>", "[CLS]", "[SEP]", "want", "unwanted", "wa", "un", "running", ",", "low", "l",
        ]
        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
            with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            input_text = u"<unk> UNwanted , running"
            output_text = u"<unk> unwanted, running"

            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)

            tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)

            tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
            self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])

            self.assertListEqual(
                tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])

    def test_full_tokenizer_lower(self):
        tokenizer = TransfoXLTokenizer(lower_case=True)
...
@@ -68,13 +61,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])


if __name__ == '__main__':
    unittest.main()
pytorch_transformers/tests/tokenization_utils_test.py (new file)
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import six

from pytorch_transformers import PreTrainedTokenizer
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer


class TokenizerUtilsTest(unittest.TestCase):
    def check_tokenizer_from_pretrained(self, tokenizer_class):
        s3_models = list(tokenizer_class.max_model_input_sizes.keys())
        for model_name in s3_models[:1]:
            tokenizer = tokenizer_class.from_pretrained(model_name)
            self.assertIsNotNone(tokenizer)
            self.assertIsInstance(tokenizer, tokenizer_class)
            self.assertIsInstance(tokenizer, PreTrainedTokenizer)

            for special_tok in tokenizer.all_special_tokens:
                if six.PY2:
                    self.assertIsInstance(special_tok, unicode)
                else:
                    self.assertIsInstance(special_tok, str)
                special_tok_id = tokenizer.convert_tokens_to_ids(special_tok)
                self.assertIsInstance(special_tok_id, int)

    def test_pretrained_tokenizers(self):
        self.check_tokenizer_from_pretrained(GPT2Tokenizer)


if __name__ == "__main__":
    unittest.main()
pytorch_transformers/tests/tokenization_xlm_test.py (new file)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import json

from pytorch_transformers.tokenization_xlm import XLMTokenizer, VOCAB_FILES_NAMES

from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory


class XLMTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n",
                 "w</w>", "r</w>", "t</w>",
                 "lo", "low", "er</w>",
                 "low</w>", "lowest</w>", "newer</w>", "wider</w>", "<unk>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["l o 123", "lo w 1456", "e r</w> 1789", ""]

        with TemporaryDirectory() as tmpdirname:
            vocab_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['vocab_file'])
            merges_file = os.path.join(tmpdirname, VOCAB_FILES_NAMES['merges_file'])
            with open(vocab_file, "w") as fp:
                fp.write(json.dumps(vocab_tokens))
            with open(merges_file, "w") as fp:
                fp.write("\n".join(merges))

            input_text = u"lower newer"
            output_text = u"lower newer"

            create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)

            tokenizer = XLMTokenizer(vocab_file, merges_file)

            text = "lower"
            bpe_tokens = ["low", "er</w>"]
            tokens = tokenizer.tokenize(text)
            self.assertListEqual(tokens, bpe_tokens)

            input_tokens = tokens + ["<unk>"]
            input_bpe_tokens = [14, 15, 20]
            self.assertListEqual(
                tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)


if __name__ == '__main__':
    unittest.main()
pytorch_transformers/tests/tokenization_xlnet_test.py (new file)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest

from pytorch_transformers.tokenization_xlnet import (XLNetTokenizer, SPIECE_UNDERLINE)

from .tokenization_tests_commons import create_and_check_tokenizer_commons, TemporaryDirectory

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')


class XLNetTokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        with TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)

            input_text = u"This is a test"
            output_text = u"This is a test"

            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and',
                                      SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
                                      u'al', u's', u'é', u'.'])
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                           u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                           u'<unk>', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and',
                                           SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
                                           u'al', u's', u'<unk>', u'.'])

    def test_tokenizer_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and',
                                      SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
                                      u'al', u'se', u'.'])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])

    def test_tokenizer_no_lower(self):
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and',
                                      SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
                                      u'al', u'se', u'.'])


if __name__ == '__main__':
    unittest.main()
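For reference, the SPIECE_UNDERLINE prefix ('▁') marks pieces that begin a new word in SentencePiece output. A small illustrative helper (not part of the library) showing how a surface string relates to such pieces:

# Illustrative only: the '▁' marker encodes word boundaries, so joining the
# pieces and mapping '▁' back to a space recovers the original text.
SPIECE_UNDERLINE = u'\u2581'

def pieces_to_text(pieces):
    return ''.join(pieces).replace(SPIECE_UNDERLINE, ' ').strip()

assert pieces_to_text([u'\u2581This', u'\u2581is', u'\u2581a', u'\u2581t', u'est']) == u'This is a test'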
pytorch_pretrained_bert/tokenization.py → pytorch_transformers/tokenization_bert.py
...
@@ -22,26 +22,32 @@ import os
import unicodedata
from io import open

from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
        'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
        'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
        'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
        'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
        'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'bert-base-uncased': 512,
    'bert-large-uncased': 512,
    'bert-base-cased': 512,
...
@@ -56,21 +62,15 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'bert-large-cased-whole-word-masking-finetuned-squad': 512,
    'bert-base-cased-finetuned-mrpc': 512,
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.read().splitlines()
    for index, token in enumerate(tokens):
        vocab[token] = index
    return vocab
...
@@ -83,25 +83,48 @@ def whitespace_tokenize(text):
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a BertTokenizer.
    :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
        do_basic_tokenize: Whether to do basic tokenization before wordpiece.
        max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
            minimum of this value (if specified) and the underlying BERT model's sequence length.
        never_split: List of tokens which will never be split during tokenization. Only has an effect when
            do_wordpiece_only=False
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
        """Constructs a BertTokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input
                Only has an effect when do_basic_tokenize=True
            **do_basic_tokenize**: (`optional`) boolean (default True)
                Whether to do basic tokenization before wordpiece.
            **never_split**: (`optional`) list of string
                List of tokens which will never be split during tokenization.
                Only has an effect when do_basic_tokenize=True
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be desactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                            pad_token=pad_token, cls_token=cls_token,
                                            mask_token=mask_token, **kwargs)
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
...
@@ -111,46 +134,43 @@ class BertTokenizer(object):
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split,
                                                  tokenize_chinese_chars=tokenize_chinese_chars)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
...
@@ -159,16 +179,13 @@ class BertTokenizer(object):
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """ Instantiate a BertTokenizer from pre-trained vocabulary files.
        """
        if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
            if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
                logger.warning("The pre-trained model you are loading is a cased model but you have not set "
                               "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
...
@@ -179,58 +196,44 @@ class BertTokenizer(object):
                               "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
                               "but you may want to check this behavior.")
                kwargs['do_lower_case'] = True

        return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
        """ Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of token not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Whether to tokenize Chinese characters.
                This should likely be desactivated for Japanese:
                see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """ Basic Tokenization of a piece of text.
            Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
                List of token not to split.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
...
@@ -238,11 +241,12 @@ class BasicTokenizer(object):
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))
...
@@ -261,9 +265,9 @@ class BasicTokenizer(object):
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
...
@@ -335,7 +339,7 @@ class BasicTokenizer(object):
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word
...
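A short usage sketch of the refactored BertTokenizer API above; the local vocabulary path is a placeholder, and the pieces shown in the comments depend on the vocabulary actually used:

# Sketch (assumed local vocab.txt path): basic split + WordPiece, id conversion,
# and the ' ##' joining rule from convert_tokens_to_string.
from pytorch_transformers.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer("path/to/vocab.txt", do_lower_case=True)   # placeholder path
tokens = tokenizer.tokenize(u"UNwanted, running")                    # e.g. ['un', '##want', '##ed', ',', 'running']
ids = tokenizer.convert_tokens_to_ids(tokens)                        # unknown pieces fall back to the [UNK] id
text = tokenizer.convert_tokens_to_string(tokens)                    # ' '.join + ' ##' removal -> "unwanted , running"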
pytorch_pretrained_bert/tokenization_gpt2.py → pytorch_transformers/tokenization_gpt2.py
...
@@ -31,24 +31,32 @@ except ImportError:
    def lru_cache():
        return lambda func: func

from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
    },
    'merges_file':
    {
        'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
        'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'gpt2': 1024,
    'gpt2-medium': 1024,
}


@lru_cache()
def bytes_to_unicode():
...
@@ -85,71 +93,19 @@ def get_pairs(word):
        prev_char = char
    return pairs


class GPT2Tokenizer(PreTrainedTokenizer):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, errors='replace',
                 bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs):
        super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)

        self.encoder = json.load(open(vocab_file))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
...
@@ -163,25 +119,9 @@ class GPT2Tokenizer(object):
        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    @property
    def vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        if token in self.cache:
...
@@ -224,7 +164,7 @@ class GPT2Tokenizer(object):
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
...
@@ -235,59 +175,29 @@ class GPT2Tokenizer(object):
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab. """
        if token in self.encoder:
            return self.encoder.get(token)
        return self.encoder.get(self.unk_token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))
...
@@ -303,14 +213,4 @@ class GPT2Tokenizer(object):
            writer.write(' '.join(bpe_tokens) + u'\n')
            index += 1

        return vocab_file, merge_file
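A brief sketch of the byte-level BPE round trip the class above implements; 'gpt2' resolves through the pretrained maps shown earlier and triggers a download on first use:

# Sketch: byte-level BPE is lossless, so encode followed by decode recovers the input
# (for text without the punctuation patterns rewritten by clean_up_tokenization).
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode(u"Hello world")   # _tokenize() + _convert_token_to_id() under the hood
text = tokenizer.decode(ids)             # convert_tokens_to_string() reverses the byte mapping
assert text == u"Hello world"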
pytorch_pretrained_bert/tokenization_openai.py → pytorch_transformers/tokenization_openai.py
...
@@ -20,28 +20,32 @@ import json
import logging
import os
import re
from io import open

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
    },
    'merges_file':
    {
        'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'openai-gpt': 512,
}


def get_pairs(word):
    """
...
@@ -70,73 +74,19 @@ def text_standardize(text):
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()


class OpenAIGPTTokenizer(PreTrainedTokenizer):
    """
    BPE tokenizer. Peculiarities:
        - lower case all inputs
        - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
        super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs)

        try:
            import ftfy
            import spacy
...
@@ -144,39 +94,19 @@ class OpenAIGPTTokenizer(object):
            self.fix_text = ftfy.fix_text
        except ImportError:
            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
            self.nlp = BasicTokenizer(do_lower_case=True)
special_tokens
if
special_tokens
is
not
None
else
[])
self
.
fix_text
=
None
self
.
fix_text
=
None
self
.
max_len
=
max_len
if
max_len
is
not
None
else
int
(
1e12
)
self
.
encoder
=
json
.
load
(
open
(
vocab_file
,
encoding
=
"utf-8"
))
self
.
encoder
=
json
.
load
(
open
(
vocab_file
,
encoding
=
"utf-8"
))
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
self
.
decoder
=
{
v
:
k
for
k
,
v
in
self
.
encoder
.
items
()}
merges
=
open
(
merges_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[
1
:
-
1
]
merges
=
open
(
merges_file
,
encoding
=
'utf-8'
).
read
().
split
(
'
\n
'
)[
1
:
-
1
]
merges
=
[
tuple
(
merge
.
split
())
for
merge
in
merges
]
merges
=
[
tuple
(
merge
.
split
())
for
merge
in
merges
]
self
.
bpe_ranks
=
dict
(
zip
(
merges
,
range
(
len
(
merges
))))
self
.
bpe_ranks
=
dict
(
zip
(
merges
,
range
(
len
(
merges
))))
self
.
cache
=
{}
self
.
cache
=
{}
self
.
special_tokens
=
{}
self
.
special_tokens_decoder
=
{}
self
.
set_special_tokens
(
special_tokens
)
def
__len__
(
self
):
@
property
return
len
(
self
.
encoder
)
+
len
(
self
.
special_tokens
)
def
vocab_size
(
self
):
return
len
(
self
.
encoder
)
def
set_special_tokens
(
self
,
special_tokens
):
""" Add a list of additional tokens to the encoder.
The additional tokens are indexed starting from the last index of the
current vocabulary in the order of the `special_tokens` list.
"""
if
not
special_tokens
:
self
.
special_tokens
=
{}
self
.
special_tokens_decoder
=
{}
return
self
.
special_tokens
=
dict
((
tok
,
len
(
self
.
encoder
)
+
i
)
for
i
,
tok
in
enumerate
(
special_tokens
))
self
.
special_tokens_decoder
=
{
v
:
k
for
k
,
v
in
self
.
special_tokens
.
items
()}
if
self
.
fix_text
is
None
:
# Using BERT's BasicTokenizer: we can update the tokenizer
self
.
nlp
.
never_split
=
special_tokens
logger
.
info
(
"Special tokens {}"
.
format
(
self
.
special_tokens
))
def
bpe
(
self
,
token
):
def
bpe
(
self
,
token
):
word
=
tuple
(
token
[:
-
1
])
+
(
token
[
-
1
]
+
'</w>'
,)
word
=
tuple
(
token
[:
-
1
])
+
(
token
[
-
1
]
+
'</w>'
,)
...
@@ -221,7 +151,7 @@ class OpenAIGPTTokenizer(object):
...
@@ -221,7 +151,7 @@ class OpenAIGPTTokenizer(object):
self
.
cache
[
token
]
=
word
self
.
cache
[
token
]
=
word
return
word
return
word
def
tokenize
(
self
,
text
):
def
_
tokenize
(
self
,
text
):
""" Tokenize a string. """
""" Tokenize a string. """
split_tokens
=
[]
split_tokens
=
[]
if
self
.
fix_text
is
None
:
if
self
.
fix_text
is
None
:
...
@@ -236,60 +166,26 @@ class OpenAIGPTTokenizer(object):
...
@@ -236,60 +166,26 @@ class OpenAIGPTTokenizer(object):
split_tokens
.
extend
([
t
for
t
in
self
.
bpe
(
token
.
text
.
lower
()).
split
(
' '
)])
split_tokens
.
extend
([
t
for
t
in
self
.
bpe
(
token
.
text
.
lower
()).
split
(
' '
)])
return
split_tokens
return
split_tokens
def
convert_tokens_to_ids
(
self
,
tokens
):
def
_convert_token_to_id
(
self
,
token
):
""" Converts a sequence of tokens into ids using the vocab. """
""" Converts a token (str/unicode) in an id using the vocab. """
ids
=
[]
return
self
.
encoder
.
get
(
token
,
self
.
encoder
.
get
(
self
.
unk_token
))
if
isinstance
(
tokens
,
str
)
or
(
sys
.
version_info
[
0
]
==
2
and
isinstance
(
tokens
,
unicode
)):
if
tokens
in
self
.
special_tokens
:
return
self
.
special_tokens
[
tokens
]
else
:
return
self
.
encoder
.
get
(
tokens
,
0
)
for
token
in
tokens
:
if
token
in
self
.
special_tokens
:
ids
.
append
(
self
.
special_tokens
[
token
])
else
:
ids
.
append
(
self
.
encoder
.
get
(
token
,
0
))
if
len
(
ids
)
>
self
.
max_len
:
logger
.
warning
(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors"
.
format
(
len
(
ids
),
self
.
max_len
)
)
return
ids
def
convert_ids_to_tokens
(
self
,
ids
,
skip_special_tokens
=
False
):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens
=
[]
for
i
in
ids
:
if
i
in
self
.
special_tokens_decoder
:
if
not
skip_special_tokens
:
tokens
.
append
(
self
.
special_tokens_decoder
[
i
])
else
:
tokens
.
append
(
self
.
decoder
[
i
])
return
tokens
def
encode
(
self
,
text
):
def
_convert_id_to_token
(
self
,
index
):
return
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
))
"""Converts an id in a token (BPE) using the vocab."""
return
self
.
decoder
.
get
(
index
,
self
.
unk_token
)
def
decode
(
self
,
ids
,
skip_special_tokens
=
False
,
clean_up_tokenization_spaces
=
True
):
def
convert_tokens_to_string
(
self
,
tokens
):
"""Converts a sequence of ids in a string."""
""" Converts a sequence of tokens (string) in a single string. """
tokens
=
self
.
convert_ids_to_tokens
(
ids
,
skip_special_tokens
=
skip_special_tokens
)
out_string
=
''
.
join
(
tokens
).
replace
(
'</w>'
,
' '
).
strip
()
out_string
=
''
.
join
(
tokens
).
replace
(
'</w>'
,
' '
).
strip
()
if
clean_up_tokenization_spaces
:
out_string
=
out_string
.
replace
(
'<unk>'
,
''
)
out_string
=
out_string
.
replace
(
' .'
,
'.'
).
replace
(
' ?'
,
'?'
).
replace
(
' !'
,
'!'
).
replace
(
' ,'
,
','
).
replace
(
" ' "
,
"'"
).
replace
(
" n't"
,
"n't"
).
replace
(
" 'm"
,
"'m"
).
replace
(
" do not"
,
" don't"
).
replace
(
" 's"
,
"'s"
).
replace
(
" 've"
,
"'ve"
).
replace
(
" 're"
,
"'re"
)
return
out_string
return
out_string
def
save_vocabulary
(
self
,
vocab_path
):
def
save_vocabulary
(
self
,
save_directory
):
"""Save the tokenizer vocabulary and merge files to a directory."""
"""Save the tokenizer vocabulary and merge files to a directory."""
if
not
os
.
path
.
isdir
(
vocab_path
):
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
vocab_path
))
logger
.
error
(
"Vocabulary path ({}) should be a directory"
.
format
(
save_directory
))
return
return
vocab_file
=
os
.
path
.
join
(
vocab_path
,
VOCAB_NAME
)
vocab_file
=
os
.
path
.
join
(
save_directory
,
VOCAB_FILES_NAMES
[
'vocab_file'
])
merge_file
=
os
.
path
.
join
(
vocab_path
,
MERGES_NAME
)
merge_file
=
os
.
path
.
join
(
save_directory
,
VOCAB_FILES_NAMES
[
'merges_file'
])
special_tokens_file
=
os
.
path
.
join
(
vocab_path
,
SPECIAL_TOKENS_NAME
)
with
open
(
vocab_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
vocab_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
json
.
dumps
(
self
.
encoder
,
ensure_ascii
=
False
))
f
.
write
(
json
.
dumps
(
self
.
encoder
,
ensure_ascii
=
False
))
...
@@ -305,14 +201,4 @@ class OpenAIGPTTokenizer(object):
...
@@ -305,14 +201,4 @@ class OpenAIGPTTokenizer(object):
writer
.
write
(
' '
.
join
(
bpe_tokens
)
+
u
'
\n
'
)
writer
.
write
(
' '
.
join
(
bpe_tokens
)
+
u
'
\n
'
)
index
+=
1
index
+=
1
index
=
len
(
self
.
encoder
)
return
vocab_file
,
merge_file
with
open
(
special_tokens_file
,
'w'
,
encoding
=
'utf-8'
)
as
writer
:
for
token
,
token_index
in
sorted
(
self
.
special_tokens
.
items
(),
key
=
lambda
kv
:
kv
[
1
]):
if
index
!=
token_index
:
logger
.
warning
(
"Saving special tokens vocabulary to {}: BPE indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
.
format
(
special_tokens_file
))
index
=
token_index
writer
.
write
(
token
+
u
'
\n
'
)
index
+=
1
return
vocab_file
,
merge_file
,
special_tokens_file
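For orientation (not part of the commit): after this refactor the OpenAI GPT tokenizer only implements the private hooks (_tokenize, _convert_token_to_id, ...) and gets the public API from PreTrainedTokenizer. A minimal usage sketch, assuming the S3 vocabulary files listed above are reachable; the '<cls>'/'<sep>' strings are hypothetical examples.

from pytorch_transformers import OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
tokens = tokenizer.tokenize("Hello world")   # inherited tokenize() calls _tokenize()
ids = tokenizer.encode("Hello world")        # convert_tokens_to_ids(tokenize(text))
text = tokenizer.decode(ids)                 # goes through convert_tokens_to_string()

# special tokens now go through the shared base-class machinery instead of
# the removed set_special_tokens() helper
tokenizer.add_special_tokens({'cls_token': '<cls>', 'sep_token': '<sep>'})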
pytorch_pretrained_bert/tokenization_transfo_xl.py → pytorch_transformers/tokenization_transfo_xl.py
View file @ f31154cb
...
@@ -25,12 +25,12 @@ import os
 import sys
 from collections import Counter, OrderedDict
 from io import open
-import unicodedata

 import torch
 import numpy as np

 from .file_utils import cached_path
+from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization

 if sys.version_info[0] == 2:
     import cPickle as pickle
...
@@ -40,66 +40,43 @@ else:
 logger = logging.getLogger(__name__)

-PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
-}
-VOCAB_NAME = 'vocab.bin'
+VOCAB_FILES_NAMES = {'pretrained_vocab_file': 'vocab.bin', 'vocab_file': 'vocab.txt'}
+PRETRAINED_VOCAB_FILES_MAP = {
+    'pretrained_vocab_file': {
+        'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
+    }
+}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    'transfo-xl-wt103': None,
+}

 PRETRAINED_CORPUS_ARCHIVE_MAP = {
     'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
 }
 CORPUS_NAME = 'corpus.bin'

-class TransfoXLTokenizer(object):
+class TransfoXLTokenizer(PreTrainedTokenizer):
     """
     Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl
     """
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
-        """
-        Instantiate a TransfoXLTokenizer.
-        The TransfoXLTokenizer.
-        """
-        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
-        else:
-            if os.path.isdir(pretrained_model_name_or_path):
-                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
-            else:
-                vocab_file = pretrained_model_name_or_path
-        # redirect to the cache, if necessary
-        try:
-            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
-        except EnvironmentError:
-            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
-                logger.error("Couldn't reach server at '{}' to download vocabulary.".format(vocab_file))
-            else:
-                logger.error(
-                    "Model name '{}' was not found in model name list ({}). "
-                    "We assumed '{}' was a path or url but couldn't find files {} "
-                    "at this path or url.".format(
-                        pretrained_model_name_or_path, ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                        pretrained_model_name_or_path, vocab_file))
-            return None
-        if resolved_vocab_file == vocab_file:
-            logger.info("loading vocabulary file {}".format(vocab_file))
-        else:
-            logger.info("loading vocabulary file {} from cache at {}".format(vocab_file, resolved_vocab_file))
-        # Instantiate tokenizer.
-        tokenizer = cls(*inputs, **kwargs)
-        vocab_dict = torch.load(resolved_vocab_file)
-        for key, value in vocab_dict.items():
-            tokenizer.__dict__[key] = value
-        return tokenizer
-
-    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=False,
-                 delimiter=None, vocab_file=None, never_split=("<unk>", "<eos>", "<formula>")):
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
+                 delimiter=None, vocab_file=None, pretrained_vocab_file=None,
+                 never_split=None, unk_token="<unk>", eos_token="<eos>",
+                 additional_special_tokens=["<formula>"], **kwargs):
+        super(TransfoXLTokenizer, self).__init__(unk_token=unk_token, eos_token=eos_token,
+                                                 additional_special_tokens=additional_special_tokens,
+                                                 **kwargs)
+        if never_split is None:
+            never_split = self.all_special_tokens
+        if special is None:
+            special = []
         self.counter = Counter()
         self.special = special
         self.min_freq = min_freq
...
@@ -109,15 +86,25 @@ class TransfoXLTokenizer(object):
         self.vocab_file = vocab_file
         self.never_split = never_split
+        if pretrained_vocab_file is not None:
+            # Hack because, honestly this tokenizer was not made to be used
+            # in a library like ours, at all.
+            vocab_dict = torch.load(pretrained_vocab_file)
+            for key, value in vocab_dict.items():
+                self.__dict__[key] = value
+        if vocab_file is not None:
+            self.build_vocab()

     def count_file(self, path, verbose=False, add_eos=False):
-        if verbose: print('counting file {} ...'.format(path))
+        if verbose: logger.info('counting file {} ...'.format(path))
         assert os.path.exists(path)
         sents = []
         with open(path, 'r', encoding='utf-8') as f:
             for idx, line in enumerate(f):
                 if verbose and idx > 0 and idx % 500000 == 0:
-                    print('    line {}'.format(idx))
+                    logger.info('    line {}'.format(idx))
                 symbols = self.tokenize(line, add_eos=add_eos)
                 self.counter.update(symbols)
                 sents.append(symbols)
...
@@ -128,10 +115,10 @@ class TransfoXLTokenizer(object):
         """
             sents : a list of sentences, each a list of tokenized symbols
         """
-        if verbose: print('counting {} sents ...'.format(len(sents)))
+        if verbose: logger.info('counting {} sents ...'.format(len(sents)))
         for idx, symbols in enumerate(sents):
             if verbose and idx > 0 and idx % 500000 == 0:
-                print('    line {}'.format(idx))
+                logger.info('    line {}'.format(idx))
             self.counter.update(symbols)

     def _build_from_file(self, vocab_file):
...
@@ -153,17 +140,17 @@ class TransfoXLTokenizer(object):
         """Save the tokenizer vocabulary to a directory or file."""
         index = 0
         if os.path.isdir(vocab_path):
-            vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['pretrained_vocab_file'])
         torch.save(self.__dict__, vocab_file)
-        return vocab_file
+        return (vocab_file,)

     def build_vocab(self):
         if self.vocab_file:
-            print('building vocab from {}'.format(self.vocab_file))
+            logger.info('building vocab from {}'.format(self.vocab_file))
             self._build_from_file(self.vocab_file)
-            print('final vocab size {}'.format(len(self)))
+            logger.info('final vocab size {}'.format(len(self)))
         else:
-            print('building vocab with min_freq={}, max_size={}'.format(
-                self.min_freq, self.max_size))
+            logger.info('building vocab with min_freq={}, max_size={}'.format(
+                self.min_freq, self.max_size))
             self.idx2sym = []
             self.sym2idx = OrderedDict()
...
@@ -175,18 +162,18 @@ class TransfoXLTokenizer(object):
             if cnt < self.min_freq: break
             self.add_symbol(sym)
-        print('final vocab size {} from {} unique tokens'.format(
-            len(self), len(self.counter)))
+        logger.info('final vocab size {} from {} unique tokens'.format(
+            len(self), len(self.counter)))

     def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                     add_double_eos=False):
-        if verbose: print('encoding file {} ...'.format(path))
+        if verbose: logger.info('encoding file {} ...'.format(path))
         assert os.path.exists(path)
         encoded = []
         with open(path, 'r', encoding='utf-8') as f:
             for idx, line in enumerate(f):
                 if verbose and idx > 0 and idx % 500000 == 0:
-                    print('    line {}'.format(idx))
+                    logger.info('    line {}'.format(idx))
                 symbols = self.tokenize(line, add_eos=add_eos,
                                         add_double_eos=add_double_eos)
                 encoded.append(self.convert_to_tensor(symbols))
...
@@ -197,11 +184,11 @@ class TransfoXLTokenizer(object):
         return encoded

     def encode_sents(self, sents, ordered=False, verbose=False):
-        if verbose: print('encoding {} sents ...'.format(len(sents)))
+        if verbose: logger.info('encoding {} sents ...'.format(len(sents)))
         encoded = []
         for idx, symbols in enumerate(sents):
             if verbose and idx > 0 and idx % 500000 == 0:
-                print('    line {}'.format(idx))
+                logger.info('    line {}'.format(idx))
             encoded.append(self.convert_to_tensor(symbols))
         if ordered:
...
@@ -220,15 +207,17 @@ class TransfoXLTokenizer(object):
             self.idx2sym.append(sym)
             self.sym2idx[sym] = len(self.idx2sym) - 1

-    def get_sym(self, idx):
+    def _convert_id_to_token(self, idx):
+        """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), 'Index {} out of vocabulary range'.format(idx)
         return self.idx2sym[idx]

-    def get_idx(self, sym):
+    def _convert_token_to_id(self, sym):
+        """ Converts a token (str/unicode) in an id using the vocab. """
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
-            # print('encounter unk {}'.format(sym))
+            # logger.info('encounter unk {}'.format(sym))
             # assert '<eos>' not in sym
             if hasattr(self, 'unk_idx'):
                 return self.sym2idx.get(sym, self.unk_idx)
...
@@ -240,28 +229,19 @@ class TransfoXLTokenizer(object):
             else:
                 raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

-    def convert_ids_to_tokens(self, indices):
-        """Converts a sequence of indices in symbols using the vocab."""
-        return [self.get_sym(idx) for idx in indices]
-
-    def convert_tokens_to_ids(self, symbols):
-        """Converts a sequence of symbols into ids using the vocab."""
-        return [self.get_idx(sym) for sym in symbols]
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).strip()
+        return out_string

     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))

-    def decode(self, indices, exclude=None):
-        """Converts a sequence of indices in a string."""
-        if exclude is None:
-            return ' '.join([self.get_sym(idx) for idx in indices])
-        else:
-            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
-
-    def __len__(self):
+    @property
+    def vocab_size(self):
         return len(self.idx2sym)

-    def tokenize(self, line, add_eos=False, add_double_eos=False):
+    def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case
         if self.lower_case:
...
@@ -472,7 +452,7 @@ class TransfoXLCorpus(object):
                     "We assumed '{}' was a path or url but couldn't find files {} "
                     "at this path or url.".format(
                         pretrained_model_name_or_path,
-                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        ', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                         pretrained_model_name_or_path,
                         corpus_file))
             return None
...
@@ -563,14 +543,14 @@ def get_lm_corpus(datadir, dataset):
     fn = os.path.join(datadir, 'cache.pt')
     fn_pickle = os.path.join(datadir, 'cache.pkl')
     if os.path.exists(fn):
-        print('Loading cached dataset...')
+        logger.info('Loading cached dataset...')
         corpus = torch.load(fn_pickle)
     elif os.path.exists(fn):
-        print('Loading cached dataset from pickle...')
+        logger.info('Loading cached dataset from pickle...')
         with open(fn, "rb") as fp:
             corpus = pickle.load(fp)
     else:
-        print('Producing dataset {}...'.format(dataset))
+        logger.info('Producing dataset {}...'.format(dataset))
         kwargs = {}
         if dataset in ['wt103', 'wt2']:
             kwargs['special'] = ['<eos>']
...
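For orientation (not part of the commit): the word-level vocabulary workflow of this tokenizer is unchanged by the refactor, only wrapped by the new base class. A minimal sketch assuming a hypothetical whitespace-tokenized text file 'train.txt'.

from pytorch_transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer(lower_case=False)   # <unk>/<eos> come from the new defaults
tokenizer.count_file('train.txt', add_eos=True)    # fills self.counter
tokenizer.build_vocab()                            # builds idx2sym / sym2idx

tokens = tokenizer.tokenize("hello world")         # inherited tokenize() -> _tokenize()
ids = tokenizer.convert_tokens_to_ids(tokens)
tensor = tokenizer.convert_to_tensor(tokens)       # torch.LongTensor of ids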
pytorch_transformers/tokenization_utils.py 0 → 100644
View file @ f31154cb

# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import logging
import os
import json
import six
from io import open

from .file_utils import cached_path

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'

class PreTrainedTokenizer(object):
    """ An abstract class to handle dowloading and loading pretrained tokenizers and adding tokens to the vocabulary.

        Derived class can set up a few special tokens to be used in common scripts and internals:
            bos_token, eos_token, EOP_TOKEN, EOD_TOKEN, unk_token, sep_token, pad_token, cls_token, mask_token
            additional_special_tokens = []

        We defined an added_tokens_encoder to add new tokens to the vocabulary without having to handle the
        specific vocabulary augmentation methods of the various underlying dictionnary structures (BPE, sentencepiece...).
    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
    max_model_input_sizes = {}

    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
                                 "pad_token", "cls_token", "mask_token",
                                 "additional_special_tokens"]

    @property
    def bos_token(self):
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
        return self._bos_token

    @property
    def eos_token(self):
        if self._eos_token is None:
            logger.error("Using eos_token, but it is not set yet.")
        return self._eos_token

    @property
    def unk_token(self):
        if self._unk_token is None:
            logger.error("Using unk_token, but it is not set yet.")
        return self._unk_token

    @property
    def sep_token(self):
        if self._sep_token is None:
            logger.error("Using sep_token, but it is not set yet.")
        return self._sep_token

    @property
    def pad_token(self):
        if self._pad_token is None:
            logger.error("Using pad_token, but it is not set yet.")
        return self._pad_token

    @property
    def cls_token(self):
        if self._cls_token is None:
            logger.error("Using cls_token, but it is not set yet.")
        return self._cls_token

    @property
    def mask_token(self):
        if self._mask_token is None:
            logger.error("Using mask_token, but it is not set yet.")
        return self._mask_token

    @property
    def additional_special_tokens(self):
        if self._additional_special_tokens is None:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return self._additional_special_tokens

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._additional_special_tokens = []

        self.max_len = max_len if max_len is not None else int(1e12)
        self.added_tokens_encoder = {}
        self.added_tokens_decoder = {}

        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                setattr(self, key, value)

    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        return cls._from_pretrained(*inputs, **kwargs)

    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedTokenizer from pre-trained vocabulary files.
        Download and cache the vocabulary files if needed.
        """
        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        if pretrained_model_name_or_path in s3_models:
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
        else:
            all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
                                     'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}
            all_vocab_files_names.update(cls.vocab_files_names)
            for file_id, file_name in all_vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

        # Get files from url, cache, or disk depending on the case
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in s3_models:
                logger.error("Couldn't reach server to download vocabulary.")
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, str(vocab_files.keys())))
            return None

        for file_id, file_path in vocab_files.items():
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        # Set max length if needed
        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # wont index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            if max_len is not None and isinstance(max_len, (int, float)):
                kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

        # Merge resolved_vocab_files arguments in kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
        special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
        for args_name, file_path in resolved_vocab_files.items():
            if args_name not in kwargs:
                kwargs[args_name] = file_path
        if special_tokens_map_file is not None:
            special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
            for key, value in special_tokens_map.items():
                if key not in kwargs:
                    kwargs[key] = value

        # Instantiate tokenizer.
        tokenizer = cls(*inputs, **kwargs)

        # Add supplementary tokens.
        if added_tokens_file is not None:
            added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
            added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
            tokenizer.added_tokens_encoder.update(added_tok_encoder)
            tokenizer.added_tokens_decoder.update(added_tok_decoder)

        return tokenizer

    def save_pretrained(self, save_directory):
        """ Save the tokenizer vocabulary files (with added tokens) and the
            special-tokens-to-class-attributes-mapping to a directory, so that it
            can be re-loaded using the `from_pretrained(save_directory)` class method.
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
            return

        special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
        added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)

        with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

        with open(added_tokens_file, 'w', encoding='utf-8') as f:
            if self.added_tokens_encoder:
                out_str = json.dumps(self.added_tokens_decoder, ensure_ascii=False)
            else:
                out_str = u"{}"
            f.write(out_str)

        vocab_files = self.save_vocabulary(save_directory)

        return vocab_files + (special_tokens_map_file, added_tokens_file)

    def save_vocabulary(self, save_directory):
        """ Save the tokenizer vocabulary to a directory. This method doesn't save added tokens
            and special token mappings.
            Please use `save_pretrained()` to save the full Tokenizer state so that it can be
            reloaded using the `from_pretrained(save_directory)` class method.
        """
        raise NotImplementedError

    def vocab_size(self):
        raise NotImplementedError

    def __len__(self):
        return self.vocab_size + len(self.added_tokens_encoder)

    def add_tokens(self, new_tokens):
        """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
            vocabulary, they are added to the added_tokens_encoder with indices starting from
            the last index of the current vocabulary.

        Returns:
            Number of tokens added to the vocabulary which can be used to correspondingly
                increase the size of the associated model embedding matrices.
        """
        if not new_tokens:
            return 0

        to_add_tokens = []
        for token in new_tokens:
            if self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
                to_add_tokens.append(token)
                logger.info("Adding %s to the vocabulary", token)

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        return len(to_add_tokens)

    def add_special_tokens(self, special_tokens_dict):
        """ Add a dictionnary of special tokens (eos, pad, cls...) to the encoder and link them
            to class attributes. If the special tokens are not in the vocabulary, they are added
            to it and indexed starting from the last index of the current vocabulary.

        Returns:
            Number of tokens added to the vocabulary which can be used to correspondingly
                increase the size of the associated model embedding matrices.
        """
        if not special_tokens_dict:
            return 0

        added_special_tokens = self.add_tokens(special_tokens_dict.values())
        for key, value in special_tokens_dict.items():
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

        return added_special_tokens

    def tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Take care of added tokens.
        """
        def split_on_tokens(tok_list, text):
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)
            tok = tok_list[0]
            split_text = text.split(tok)
            return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
                        for sub_text in split_text), [])[:-1]

        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Don't take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        """ Converts a single token or a sequence of tokens (str/unicode) in a integer id
            (resp.) a sequence of ids, using the vocabulary.
        """
        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                           "for this model ({} > {}). Running this sequence through the model will result in "
                           "indexing errors".format(len(ids), self.max_len))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        raise NotImplementedError

    def encode(self, text):
        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
            same as self.convert_tokens_to_ids(self.tokenize(text)).
        """
        return self.convert_tokens_to_ids(self.tokenize(text))

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) in a token "
            (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.

            Args:
                skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            if index in self.all_special_ids and skip_special_tokens:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string.
            The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
            but we often want to remove sub-word tokenization artifacts at the same time.
        """
        return ' '.join(self.convert_ids_to_tokens(tokens))

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
            with options to remove special tokens and clean up tokenization spaces.
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        text = self.convert_tokens_to_string(filtered_tokens)
        if clean_up_tokenization_spaces:
            text = clean_up_tokenization(text)
        return text

    @property
    def special_tokens_map(self):
        """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
            values ('<unk>', '<cls>'...)
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_tokens(self):
        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
            (cls_token, unk_token...).
        """
        all_toks = []
        set_attr = self.special_tokens_map
        for attr_value in set_attr.values():
            all_toks = all_toks + (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value])
        all_toks = list(set(all_toks))
        return all_toks

    @property
    def all_special_ids(self):
        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
            class attributes (cls_token, unk_token...).
        """
        all_toks = self.all_special_tokens
        all_ids = list(self.convert_tokens_to_ids(t) for t in all_toks)
        return all_ids

def clean_up_tokenization(out_string):
    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
    return out_string
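For orientation (not part of the commit): a concrete tokenizer only has to provide a vocabulary plus the private hooks; everything else (added tokens, special tokens, encode/decode) comes from the base class above. A hypothetical minimal sketch, reusing the os/json/open imports at the top of this file; the class name and vocab layout are invented for illustration.

class WhitespaceTokenizer(PreTrainedTokenizer):
    """Toy subclass: whitespace splitting over a fixed token -> id dict."""
    def __init__(self, vocab, unk_token="<unk>", **kwargs):
        super(WhitespaceTokenizer, self).__init__(unk_token=unk_token, **kwargs)
        self.encoder = dict(vocab)                       # token -> id
        self.decoder = {v: k for k, v in self.encoder.items()}

    @property
    def vocab_size(self):
        return len(self.encoder)

    def _tokenize(self, text):
        return text.split()

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.decoder.get(index, self.unk_token)

    def save_vocabulary(self, save_directory):
        path = os.path.join(save_directory, 'vocab.json')
        with open(path, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))
        return (path,)

# tok = WhitespaceTokenizer({'<unk>': 0, 'hello': 1, 'world': 2})
# tok.add_tokens(['pytorch'])           # indexed after the base vocabulary -> id 3
# tok.encode("hello pytorch world")     # -> [1, 3, 2]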
pytorch_transformers/tokenization_xlm.py 0 → 100644
View file @ f31154cb

# coding=utf-8
# Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import json
import logging
import os
import re
from io import open

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer

logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {
    'vocab_file': 'vocab.json',
    'merges_file': 'merges.txt',
}

PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file': {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json",
    },
    'merges_file': {
        'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt",
        'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
        'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt",
        'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt",
        'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt",
        'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt",
        'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'xlm-mlm-en-2048': 512,
}

def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    word is represented as tuple of symbols (symbols being variable-length strings)
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

def text_standardize(text):
    """
    fixes some issues the spacy tokenizer had on books corpus
    also does some whitespace standardization
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()

class XLMTokenizer(PreTrainedTokenizer):
    """
    BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
        - lower case all inputs
        - uses `SpaCy tokenizer <https://spacy.io/api/tokenizer/>`_ and \
          `ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \
          fallback to BERT's BasicTokenizer if not.
        - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \
          (ex: "__classify__") to a vocabulary.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>",
                 sep_token="</s>", pad_token="<pad>", cls_token="</s>",
                 mask_token="<special1>", additional_special_tokens=["<special0>",
                 "<special1>", "<special2>", "<special3>", "<special4>", "<special5>",
                 "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs):
        super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token,
                                           sep_token=sep_token, pad_token=pad_token,
                                           cls_token=cls_token, mask_token=mask_token,
                                           additional_special_tokens=additional_special_tokens,
                                           **kwargs)
        try:
            import ftfy
            import spacy
            self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
            self.fix_text = ftfy.fix_text
        except ImportError:
            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
            self.nlp = BasicTokenizer(do_lower_case=True)
            self.fix_text = None

        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
        self.decoder = {v: k for k, v in self.encoder.items()}
        merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1]
        merges = [tuple(merge.split()[:2]) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    @property
    def vocab_size(self):
        return len(self.encoder)

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n  </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """ Tokenize a string. """
        split_tokens = []
        if self.fix_text is None:
            # Using BERT's BasicTokenizer
            text = self.nlp.tokenize(text)
            for token in text:
                split_tokens.extend([t for t in self.bpe(token).split(' ')])
        else:
            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
            text = self.nlp(text_standardize(self.fix_text(text)))
            for token in text:
                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab. """
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = ''.join(tokens).replace('</w>', ' ').strip()
        return out_string

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1
        return vocab_file, merge_file
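For orientation (not part of the commit): the new XLM tokenizer exposes the same base-class API as the other tokenizers in this release. A minimal usage sketch, assuming the S3 files listed above are reachable.

from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
ids = tokenizer.encode("Hello world")
print(tokenizer.convert_ids_to_tokens(ids))   # BPE pieces ending in '</w>'
print(tokenizer.decode(ids))                  # detokenized string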