chenpangpang / transformers · Commits · 75d5f98f

Commit 75d5f98f, authored Aug 09, 2019 by LysandreJik

    Roberta tokenization + fixed tests (py3 + py2).

parent 14e970c2

Showing 3 changed files with 141 additions and 227 deletions:

    pytorch_transformers/tests/modeling_roberta_test.py        +2    -38
    pytorch_transformers/tests/tokenization_roberta_test.py    +4    -7
    pytorch_transformers/tokenization_roberta.py                +135  -182
pytorch_transformers/tests/modeling_roberta_test.py

@@ -157,42 +157,6 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
             return config, inputs_dict
 
-        def test_inference_masked_lm(self):
-            model = RobertaForMaskedLM.from_pretrained('roberta-base')
-            input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
-            output = model(input_ids)[0]
-            expected_shape = torch.Size((1, 11, 50265))
-            self.assertEqual(output.shape, expected_shape)
-            # compare the actual values for a slice.
-            expected_slice = torch.Tensor(
-                [[[33.8843, -4.3107, 22.7779],
-                  [4.6533, -2.8099, 13.6252],
-                  [1.8222, -3.6898, 8.8600]]]
-            )
-            self.assertTrue(
-                torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
-            )
-
-        # @pytest.mark.slow
-        def test_inference_no_head(self):
-            model = RobertaModel.from_pretrained('roberta-base')
-            input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
-            output = model(input_ids)[0]
-            # compare the actual values for a slice.
-            expected_slice = torch.Tensor(
-                [[[-0.0231, 0.0782, 0.0074],
-                  [-0.1854, 0.0539, -0.0174],
-                  [0.0548, 0.0799, 0.1687]]]
-            )
-            self.assertTrue(
-                torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
-            )
-
     def setUp(self):
         self.model_tester = RobertaModelTest.RobertaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)

@@ -220,7 +184,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
 class RobertaModelIntegrationTest(unittest.TestCase):
 
-    # @pytest.mark.slow
+    @pytest.mark.slow
     def test_inference_masked_lm(self):
         model = RobertaForMaskedLM.from_pretrained('roberta-base')

@@ -241,7 +205,7 @@ class RobertaModelIntegrationTest(unittest.TestCase):
             torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
         )
 
-    # @pytest.mark.slow
+    @pytest.mark.slow
     def test_inference_no_head(self):
         model = RobertaModel.from_pretrained('roberta-base')
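The inference checks removed from the block above are not lost: they remain in RobertaModelIntegrationTest further down, now gated behind @pytest.mark.slow. As a quick reference, a standalone sketch of the masked-LM check using the same input ids and expected shape as the test, assuming the roberta-base weights are available locally and that RobertaForMaskedLM is importable from pytorch_transformers.modeling_roberta at this revision:

import torch
from pytorch_transformers.modeling_roberta import RobertaForMaskedLM

# Same input ids as the integration test above.
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
with torch.no_grad():
    output = model(input_ids)[0]   # prediction scores, shape (batch, sequence, vocab)
print(output.shape)                # the test expects torch.Size([1, 11, 50265])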
pytorch_transformers/tests/tokenization_roberta_test.py

@@ -18,8 +18,7 @@ import os
 import json
 import unittest
 
-from pytorch_transformers.tokenization_roberta import RobertaTokenizer, DICT_FILES_NAMES
-from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
+from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
 
 from .tokenization_tests_commons import CommonTestCases

@@ -45,8 +44,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
             fp.write("\n".join(merges))
 
     def get_tokenizer(self):
-        bpe_tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
-        return RobertaTokenizer.from_pretrained("roberta-base", bpe_tokenizer=bpe_tokenizer)
+        return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
 
     def get_input_output_texts(self):
         input_text = u"lower newer"

@@ -54,15 +52,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         return input_text, output_text
 
     def test_full_tokenizer(self):
-        tokenizer = self.get_tokenizer()
+        tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower"
         bpe_tokens = ["low", "er"]
         tokens = tokenizer.tokenize(text)
         self.assertListEqual(tokens, bpe_tokens)
 
         input_tokens = tokens + [tokenizer.unk_token]
-        input_bpe_tokens = [0, 4, 12, 176, 2]
-        tokenizer.convert_tokens_to_ids(input_tokens)
+        input_bpe_tokens = [13, 12, 17]
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
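With the GPT2Tokenizer indirection gone, the updated test_full_tokenizer builds a RobertaTokenizer directly from local vocab.json / merges.txt files. Below is a minimal, self-contained sketch of that construction path; the toy vocabulary, merge list and resulting ids are illustrative assumptions, not the repository's test fixtures:

import json
import os
import tempfile

from pytorch_transformers.tokenization_roberta import RobertaTokenizer

tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.json")
merges_file = os.path.join(tmpdir, "merges.txt")

# Toy byte-level BPE vocabulary and merges (illustrative only).
vocab = {"l": 0, "o": 1, "w": 2, "e": 3, "r": 4, "lo": 5, "low": 6, "er": 7, "<unk>": 8}
merges = ["#version: 0.2", "l o", "lo w", "e r", ""]  # first and last lines are dropped on load

with open(vocab_file, "w", encoding="utf-8") as fp:
    fp.write(json.dumps(vocab))
with open(merges_file, "w", encoding="utf-8") as fp:
    fp.write("\n".join(merges))

tokenizer = RobertaTokenizer(vocab_file, merges_file, unk_token="<unk>")
print(tokenizer.tokenize("lower"))                               # ['low', 'er'] with these toy merges
print(tokenizer.convert_tokens_to_ids(["low", "er", "<unk>"]))   # [6, 7, 8]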
pytorch_transformers/tokenization_roberta.py

@@ -12,229 +12,182 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tokenization classes for RoBERTa."""
+"""Tokenization classes for OpenAI GPT."""
 from __future__ import (absolute_import, division, print_function,
                         unicode_literals)
 
+import sys
 import json
 import logging
-import re
-from io import open
-
-import six
 import os
+import regex as re
+from io import open
 
+from .tokenization_gpt2 import bytes_to_unicode, get_pairs
 from .tokenization_utils import PreTrainedTokenizer
-from .tokenization_gpt2 import GPT2Tokenizer
+
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache():
+        return lambda func: func
 
 logger = logging.getLogger(__name__)
 
-DICT_FILES_NAMES = {
-    'dict_file': 'dict.txt',
+VOCAB_FILES_NAMES = {
+    'vocab_file': 'vocab.json',
+    'merges_file': 'merges.txt',
 }
 
-PRETRAINED_DICT_FILES_MAP = {
-    'dict_file':
+PRETRAINED_VOCAB_FILES_MAP = {
+    'vocab_file':
     {
-        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
-        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
-        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
+        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
+        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
+        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
+    },
+    'merges_file':
+    {
+        'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
+        'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
+        'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
     },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'roberta-base': 512,
-    'roberta-large': 512,
-    'roberta-large-mnli': 512,
+    'roberta-base': 1024,
+    'roberta-large': 1024,
+    'roberta-large-mnli': 1024,
 }
 
-SPACE_NORMALIZER = re.compile(r"\s+")
-
-
-def tokenize_line(line):
-    line = SPACE_NORMALIZER.sub(" ", line)
-    line = line.strip()
-    return line.split()
-
-
-class Dictionary(object):
-    """
-    A mapping from symbols to consecutive integers
-    From Facebook's fairseq.
-    """
-
-    def __init__(
-        self,
-        pad='<pad>',
-        eos='</s>',
-        unk='<unk>',
-        bos='<s>',
-        extra_special_symbols=None,
-    ):
-        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
-        self.symbols = []
-        self.count = []
-        self.indices = {}
-        self.bos_index = self.add_symbol(bos)
-        self.pad_index = self.add_symbol(pad)
-        self.eos_index = self.add_symbol(eos)
-        self.unk_index = self.add_symbol(unk)
-        if extra_special_symbols:
-            for s in extra_special_symbols:
-                self.add_symbol(s)
-        self.nspecial = len(self.symbols)
-
-    def __getitem__(self, idx):
-        if idx < len(self.symbols):
-            return self.symbols[idx]
-        return self.unk_word
-
-    def index(self, sym):
-        """Returns the index of the specified symbol"""
-        assert isinstance(sym, str)
-        if sym in self.indices:
-            return self.indices[sym]
-        return self.unk_index
-
-    def add_symbol(self, word, n=1):
-        """Adds a word to the dictionary"""
-        if word in self.indices:
-            idx = self.indices[word]
-            self.count[idx] = self.count[idx] + n
-            return idx
-        else:
-            idx = len(self.symbols)
-            self.indices[word] = idx
-            self.symbols.append(word)
-            self.count.append(n)
-            return idx
-
-    @classmethod
-    def load(cls, f, ignore_utf_errors=False):
-        """Loads the dictionary from a text file with the format:
-
-        ```
-        <symbol0> <count0>
-        <symbol1> <count1>
-        ...
-        ```
-        """
-        d = cls()
-        d.add_from_file(f, ignore_utf_errors)
-        return d
-
-    def add_from_file(self, f, ignore_utf_errors=False):
-        """
-        Loads a pre-existing dictionary from a text file and adds its symbols
-        to this instance.
-        """
-        if isinstance(f, six.string_types):
-            try:
-                if not ignore_utf_errors:
-                    with open(f, 'r', encoding='utf-8') as fd:
-                        self.add_from_file(fd)
-                else:
-                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
-                        self.add_from_file(fd)
-            except FileNotFoundError as fnfe:
-                raise fnfe
-            except UnicodeError:
-                raise Exception("Incorrect encoding detected in {}, please "
-                                "rebuild the dataset".format(f))
-            return
-
-        lines = f.read().splitlines()
-        for line in lines:
-            idx = line.rfind(' ')
-            if idx == -1:
-                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
-            word = line[:idx]
-            count = int(line[idx + 1:])
-            self.indices[word] = len(self.symbols)
-            self.symbols.append(word)
-            self.count.append(count)
-
-    def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
-                    consumer=None, append_eos=True, reverse_order=False):
-        words = line_tokenizer(line)
-        if reverse_order:
-            words = list(reversed(words))
-        nwords = len(words)
-        ids = [0] * (nwords + 1 if append_eos else nwords)
-
-        for i, word in enumerate(words):
-            if add_if_not_exist:
-                idx = self.add_symbol(word)
-            else:
-                idx = self.index(word)
-            if consumer is not None:
-                consumer(word, idx)
-            ids[i] = idx
-        if append_eos:
-            ids[nwords] = self.eos_index
-        return ids
-
 
 class RobertaTokenizer(PreTrainedTokenizer):
     """
-    RoBERTa tokenizer. Peculiarities:
-        - GPT-2 tokenizer with a different integer mapping on top.
+    GPT-2 BPE tokenizer. Peculiarities:
+        - Byte-level BPE
     """
-    vocab_files_names = DICT_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>",
-                 unk_token="<unk>", **kwargs):
-        super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
-                                               unk_token=unk_token, **kwargs)
+    def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
+                 cls_token="<s>", unk_token="<unk>", **kwargs):
+        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+                                               sep_token=sep_token, cls_token=cls_token, **kwargs)
 
-        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
-        self.dictionary = Dictionary.load(dict_file)
+        self.encoder = json.load(open(vocab_file, encoding="utf-8"))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
 
     @property
     def vocab_size(self):
-        return len(self.dictionary.indices)
+        return len(self.encoder)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
 
     def _tokenize(self, text):
-        """ Use GPT-2 Tokenizer """
-        return self.gpt2_tokenizer._tokenize(text)
+        """ Tokenize a string. """
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            if sys.version_info[0] == 2:
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+            else:
+                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
 
     def _convert_token_to_id(self, token):
-        if self.dictionary.index(token) != 3:
-            return self.dictionary.index(token)
-        return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
+        """ Converts a token (str/unicode) in an id using the vocab. """
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
-        symbol = self.dictionary[index]
-        try:
-            idx = int(symbol)
-            return self.gpt2_tokenizer._convert_id_to_token(idx)
-        except ValueError:
-            return symbol
+        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        return self.decoder.get(index)
 
     def convert_tokens_to_string(self, tokens):
-        return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
+        """ Converts a sequence of tokens (string) in a single string. """
+        text = ''.join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text
 
-    def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
-        cls = [self._convert_token_to_id(self.cls_token)]
-        tokens = super().convert_tokens_to_ids(tokens)
-        sep = [self._convert_token_to_id(self.sep_token)]
-        return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
+    def add_special_tokens_single_sentence(self, token_ids):
+        return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
 
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
-        return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
+    def add_special_tokens_sentences_pair(self, *token_ids):
+        sep = [self._convert_token_to_id(self.sep_token)]
+        cls = [self._convert_token_to_id(self.cls_token)]
+        return cls + token_ids[0] + sep + sep + token_ids[1] + sep
 
     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
         if not os.path.isdir(save_directory):
             logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
             return
-        dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
+        vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
+        merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
 
-        with open(dict_file, 'w', encoding='utf-8') as f:
-            for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
-                f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
 
-        vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
-        return vocab_files + (dict_file,)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        return vocab_file, merge_file
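After this change the class behaves as a self-contained GPT-2 style byte-level BPE tokenizer over RoBERTa's own vocab.json / merges.txt, with RoBERTa's special-token conventions layered on top. A minimal usage sketch, assuming the roberta-base files referenced in PRETRAINED_VOCAB_FILES_MAP above are reachable:

from pytorch_transformers.tokenization_roberta import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

tokens = tokenizer.tokenize(u"Hello world!")             # byte-level BPE pieces, GPT-2 style
ids = tokenizer.convert_tokens_to_ids(tokens)
ids = tokenizer.add_special_tokens_single_sentence(ids)  # prepend the <s> (cls) id, append the </s> (sep) id
print(tokens)
print(ids)
print(tokenizer.convert_tokens_to_string(tokens))        # byte-decodes back to "Hello world!"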