transformers · Commit 8d2fca07 (unverified)
Authored Dec 12, 2022 by Salvo Cavallaro, committed by GitHub on Dec 12, 2022
Made LUKE Tokenizer independent from RoBERTa (#20720)
Parent: 799cea64
Showing 1 changed file with 363 additions and 24 deletions.
src/transformers/models/luke/tokenization_luke.py (+363, -24)
...
@@ -18,11 +18,14 @@ import itertools
 import json
 import os
 from collections.abc import Mapping
+from functools import lru_cache
 from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
-from ... import RobertaTokenizer
+import regex as re
+
+from ...tokenization_utils import PreTrainedTokenizer
 from ...tokenization_utils_base import (
     ENCODE_KWARGS_DOCSTRING,
     AddedToken,
...
@@ -147,14 +150,76 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
 """
-class LukeTokenizer(RobertaTokenizer):
-    r"""
-    Construct a LUKE tokenizer.
-
-    This tokenizer inherits from [`RobertaTokenizer`] which contains most of the main methods. Users should refer to
-    this superclass for more information regarding those methods. Compared to [`RobertaTokenizer`], [`LukeTokenizer`]
-    also creates entity sequences, namely `entity_ids`, `entity_attention_mask`, `entity_token_type_ids`, and
-    `entity_position_ids` to be used by the LUKE model.
+@lru_cache()
+# Copied from transformers.models.roberta.tokenization_roberta.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to
+    whitespace/control characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.roberta.tokenization_roberta.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
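
Reviewer note: since `bytes_to_unicode` and `get_pairs` now live in this module rather than being inherited via RoBERTa, they can be sanity-checked in isolation. A minimal sketch of what they compute (the values follow directly from the definitions above):

    mapping = bytes_to_unicode()
    len(mapping)       # 256 -- every byte gets a printable unicode proxy
    mapping[ord(" ")]  # 'Ġ' (U+0120) -- space is pushed out of the whitespace/control range
    get_pairs(("h", "e", "l", "l", "o"))
    # {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}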
+
+
+class LukeTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a LUKE tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding.
+
+    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
+    be encoded differently whether it is at the beginning of the sentence (without space) or not:
+
+    ```
+    >>> from transformers import LukeTokenizer
+
+    >>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
+    >>> tokenizer("Hello world")['input_ids']
+    [0, 31414, 232, 2]
+
+    >>> tokenizer(" Hello world")['input_ids']
+    [0, 20920, 232, 2]
+    ```
+
+    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+    <Tip>
+
+    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).
+
+    </Tip>
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods. It also creates entity sequences, namely
+    `entity_ids`, `entity_attention_mask`, `entity_token_type_ids`, and `entity_position_ids` to be used by the LUKE
+    model.

     Args:
         vocab_file (`str`):
...
@@ -177,11 +242,53 @@ class LukeTokenizer(RobertaTokenizer):
         entity_token_2 (`str`, *optional*, defaults to `<ent2>`):
             The special token used to represent an entity span in a word token sequence. This token is only used when
             `task` is set to `"entity_pair_classification"`.
+        errors (`str`, *optional*, defaults to `"replace"`):
+            Paradigm to follow when decoding bytes to UTF-8. See
+            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
+            token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the beginning of
+            sequence. The token used is the `cls_token`.
+
+            </Tip>
+
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+
+            <Tip>
+
+            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+            The token used is the `sep_token`.
+
+            </Tip>
+
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        cls_token (`str`, *optional*, defaults to `"<s>"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        mask_token (`str`, *optional*, defaults to `"<mask>"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        add_prefix_space (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an initial space to the input. This allows treating the leading word just like any
+            other word. (The LUKE tokenizer detects the beginning of words by the preceding space.)
     """
     vocab_files_names = VOCAB_FILES_NAMES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    model_input_names = ["input_ids", "attention_mask"]

     def __init__(
         self,
...
@@ -197,8 +304,66 @@ class LukeTokenizer(RobertaTokenizer):
         entity_pad_token="[PAD]",
         entity_mask_token="[MASK]",
         entity_mask2_token="[MASK2]",
+        errors="replace",
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        add_prefix_space=False,
         **kwargs
     ):
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
+        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
+
+        # Mask token behaves like a normal word, i.e. includes the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        super().__init__(
+            errors=errors,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            add_prefix_space=add_prefix_space,
+            task=task,
+            max_entity_length=32,
+            max_mention_length=30,
+            entity_token_1="<ent>",
+            entity_token_2="<ent2>",
+            entity_unk_token=entity_unk_token,
+            entity_pad_token=entity_pad_token,
+            entity_mask_token=entity_mask_token,
+            entity_mask2_token=entity_mask2_token,
+            **kwargs,
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            bpe_merges = merges_handle.read().split("\n")[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.add_prefix_space = add_prefix_space
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
         # we add 2 special tokens for downstream tasks
         # for more information about lstrip and rstrip, see https://github.com/huggingface/transformers/pull/2778
         entity_token_1 = (
...
@@ -214,21 +379,6 @@ class LukeTokenizer(RobertaTokenizer):
         kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
         kwargs["additional_special_tokens"] += [entity_token_1, entity_token_2]
-        super().__init__(
-            vocab_file=vocab_file,
-            merges_file=merges_file,
-            task=task,
-            max_entity_length=32,
-            max_mention_length=30,
-            entity_token_1="<ent>",
-            entity_token_2="<ent2>",
-            entity_unk_token=entity_unk_token,
-            entity_pad_token=entity_pad_token,
-            entity_mask_token=entity_mask_token,
-            entity_mask2_token=entity_mask2_token,
-            **kwargs,
-        )
         with open(entity_vocab_file, encoding="utf-8") as entity_vocab_handle:
             self.entity_vocab = json.load(entity_vocab_handle)
         for entity_special_token in [entity_unk_token, entity_pad_token, entity_mask_token, entity_mask2_token]:
...
@@ -257,6 +407,171 @@ class LukeTokenizer(RobertaTokenizer):
         self.max_mention_length = max_mention_length
+    @property
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.vocab_size with Roberta->Luke, RoBERTa->LUKE
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_vocab with Roberta->Luke, RoBERTa->LUKE
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
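
Reviewer note: a quick way to see the difference between the two methods. The numbers assume the `studio-ousia/luke-base` checkpoint, whose base vocabulary should be the 50,265-entry RoBERTa one plus the added entity tokens:

    tokenizer.vocab_size        # 50265 -- the base BPE vocabulary only
    len(tokenizer.get_vocab())  # 50267 -- also counts added tokens such as <ent> and <ent2>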
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.bpe with Roberta->Luke, RoBERTa->LUKE
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
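
Reviewer note: a toy walkthrough of the merge loop above, assuming a hypothetical three-entry merge table rather than the real merges file:

    bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}  # lower rank merges first
    # token "lower" -> ('l', 'o', 'w', 'e', 'r')
    # rank 0 merges ('l', 'o')  -> ('lo', 'w', 'e', 'r')
    # rank 1 merges ('lo', 'w') -> ('low', 'e', 'r')
    # rank 2 merges ('e', 'r')  -> ('low', 'er')
    # no remaining pair is in bpe_ranks -> bpe() returns "low er"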
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._tokenize with Roberta->Luke, RoBERTa->LUKE
+    def _tokenize(self, text):
+        """Tokenize a string."""
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_token_to_id with Roberta->Luke, RoBERTa->LUKE
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer._convert_id_to_token with Roberta->Luke, RoBERTa->LUKE
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index)
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.convert_tokens_to_string with Roberta->Luke, RoBERTa->LUKE
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        text = "".join(tokens)
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
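
Reviewer note: `_tokenize`, the id conversions and `convert_tokens_to_string` round-trip through the byte encoder. A minimal sketch with the tokenizer instantiated as in the class docstring (the ids follow from the docstring's own example for `" Hello world"`):

    tokens = tokenizer.tokenize(" Hello world")
    # ['ĠHello', 'Ġworld'] -- the leading space is byte-encoded as 'Ġ'
    tokenizer.convert_tokens_to_ids(tokens)
    # [20920, 232]
    tokenizer.convert_tokens_to_string(tokens)
    # ' Hello world'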
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.build_inputs_with_special_tokens with Roberta->Luke, RoBERTa->LUKE
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A LUKE sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s></s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
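
Reviewer note: with the luke-base vocabulary, `cls_token_id` is 0 and `sep_token_id` is 2 (visible in the docstring example above), so the two formats look like this for toy ids:

    tokenizer.build_inputs_with_special_tokens([100, 200])
    # [0, 100, 200, 2]        -- <s> X </s>
    tokenizer.build_inputs_with_special_tokens([100], [200])
    # [0, 100, 2, 2, 200, 2]  -- <s> A </s></s> B </s>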
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.get_special_tokens_mask with Roberta->Luke, RoBERTa->LUKE
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
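
Reviewer note: the mask lines up one-to-one with the sequences built above; a sketch for the same toy ids:

    tokenizer.get_special_tokens_mask([100, 200])
    # [1, 0, 0, 1]
    tokenizer.get_special_tokens_mask([100], [200])
    # [1, 0, 1, 1, 0, 1]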
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.create_token_type_ids_from_sequences with Roberta->Luke, RoBERTa->LUKE
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. LUKE does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+    # Copied from transformers.models.roberta.tokenization_roberta.RobertaTokenizer.prepare_for_tokenization with Roberta->Luke, RoBERTa->LUKE
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+            text = " " + text
+        return (text, kwargs)
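
Reviewer note: `prepare_for_tokenization` is where `add_prefix_space` takes effect before `_tokenize` runs; it returns the (possibly modified) text together with the remaining kwargs:

    tokenizer.prepare_for_tokenization("Hello world")
    # ('Hello world', {})
    tokenizer.prepare_for_tokenization("Hello world", add_prefix_space=True)
    # (' Hello world', {})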
     @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
     def __call__(
         self,
...
@@ -1377,7 +1692,31 @@ class LukeTokenizer(RobertaTokenizer):
         return encoded_inputs
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        vocab_file, merge_file = super().save_vocabulary(save_directory, filename_prefix)
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write("#version: 0.2\n")
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
         entity_vocab_file = os.path.join(
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"]
         )
...
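
Reviewer note: the rewritten `save_vocabulary` now writes `vocab.json` and `merges.txt` itself instead of delegating to `RobertaTokenizer`. A minimal round-trip sketch to confirm nothing is lost (uses `save_pretrained`, which calls `save_vocabulary` internally):

    import tempfile
    from transformers import LukeTokenizer

    tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
    with tempfile.TemporaryDirectory() as tmp:
        tokenizer.save_pretrained(tmp)
        reloaded = LukeTokenizer.from_pretrained(tmp)
        assert reloaded.get_vocab() == tokenizer.get_vocab()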