Unverified commit c2b83d54, authored Oct 07, 2022 by harry7337 and committed by GitHub on Oct 07, 2022.
Removed Bert and XLM Dependency from Herbert (#19410)
Co-authored-by: harry7337 <hari.8jan@gmail.com>
Parent: e6fc2016
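In effect, the commit copies the BERT `BasicTokenizer` and the XLM BPE/Moses helper code directly into `tokenization_herbert.py`, so `HerbertTokenizer` subclasses `PreTrainedTokenizer` instead of `XLMTokenizer` and no longer imports anything from the `bert` or `xlm` modules. The public behaviour is intended to stay the same. A minimal usage sketch, assuming the `allegro/herbert-base-cased` checkpoint referenced in this file and an environment with `sacremoses` installed (the refactored `__init__` raises `ImportError` without it):

```
from transformers import HerbertTokenizer

# The slow (pure Python) tokenizer touched by this commit; the fast tokenizer file is not part of it.
tokenizer = HerbertTokenizer.from_pretrained("allegro/herbert-base-cased")

# BERT-style pre-tokenization, then BPE sub-tokenization, then XLM-style special tokens.
encoding = tokenizer("Zażółć gęślą jaźń.")
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))  # <s> ... BPE pieces ... </s>
```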
Showing 1 changed file with 554 additions and 8 deletions:

src/transformers/models/herbert/tokenization_herbert.py (+554, −8)
@@ -12,10 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import os
+import re
+import unicodedata
+from typing import List, Optional, Tuple

+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...utils import logging
-from ..bert.tokenization_bert import BasicTokenizer
-from ..xlm.tokenization_xlm import XLMTokenizer


 logger = logging.get_logger(__name__)
@@ -38,7 +42,239 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"allegro/herbert-base-cased": 514}
 PRETRAINED_INIT_CONFIGURATION = {}


-class HerbertTokenizer(XLMTokenizer):
+# Copied from transformers.models.xlm.tokenization_xlm.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
+    strings)
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.replace_unicode_punct
+def replace_unicode_punct(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
+    """
+    text = text.replace("，", ",")
+    text = re.sub(r"。\s*", ". ", text)
+    text = text.replace("、", ",")
+    text = text.replace("”", '"')
+    text = text.replace("“", '"')
+    text = text.replace("∶", ":")
+    text = text.replace("：", ":")
+    text = text.replace("？", "?")
+    text = text.replace("《", '"')
+    text = text.replace("》", '"')
+    text = text.replace("）", ")")
+    text = text.replace("！", "!")
+    text = text.replace("（", "(")
+    text = text.replace("；", ";")
+    text = text.replace("１", "1")
+    text = text.replace("」", '"')
+    text = text.replace("「", '"')
+    text = text.replace("０", "0")
+    text = text.replace("３", "3")
+    text = text.replace("２", "2")
+    text = text.replace("５", "5")
+    text = text.replace("６", "6")
+    text = text.replace("９", "9")
+    text = text.replace("７", "7")
+    text = text.replace("８", "8")
+    text = text.replace("４", "4")
+    text = re.sub(r"．\s*", ". ", text)
+    text = text.replace("～", "~")
+    text = text.replace("’", "'")
+    text = text.replace("…", "...")
+    text = text.replace("━", "-")
+    text = text.replace("〈", "<")
+    text = text.replace("〉", ">")
+    text = text.replace("【", "[")
+    text = text.replace("】", "]")
+    text = text.replace("％", "%")
+    return text
+
+
+# Copied from transformers.models.xlm.tokenization_xlm.remove_non_printing_char
+def remove_non_printing_char(text):
+    """
+    Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
+    """
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat.startswith("C"):
+            continue
+        output.append(char)
+    return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer(object):
+    """
+    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+    Args:
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+    """
+
+    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
+        if never_split is None:
+            never_split = []
+        self.do_lower_case = do_lower_case
+        self.never_split = set(never_split)
+        self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents
+
+    def tokenize(self, text, never_split=None):
+        """
+        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
+        WordPieceTokenizer.
+
+        Args:
+            never_split (`List[str]`, *optional*)
+                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
+        """
+        # union() returns a new set by concatenating the two sets.
+        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
+        if self.tokenize_chinese_chars:
+            text = self._tokenize_chinese_chars(text)
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text, never_split=None):
+        """Splits punctuation on a piece of text."""
+        if never_split is not None and text in never_split:
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like the all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)  #
+            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+        ):  #
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xFFFD or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+class HerbertTokenizer(PreTrainedTokenizer):
     """
     Construct a BPE tokenizer for HerBERT.
@@ -68,22 +304,74 @@ class HerbertTokenizer(XLMTokenizer):
         pad_token="<pad>",
         mask_token="<mask>",
         sep_token="</s>",
+        bos_token="<s>",
         do_lowercase_and_remove_accent=False,
+        additional_special_tokens=[
+            "<special0>",
+            "<special1>",
+            "<special2>",
+            "<special3>",
+            "<special4>",
+            "<special5>",
+            "<special6>",
+            "<special7>",
+            "<special8>",
+            "<special9>",
+        ],
+        lang2id=None,
+        id2lang=None,
         **kwargs
     ):

         super().__init__(
-            vocab_file,
-            merges_file,
-            tokenizer_file=None,
-            cls_token=cls_token,
             unk_token=unk_token,
+            bos_token=bos_token,
+            sep_token=sep_token,
             pad_token=pad_token,
+            cls_token=cls_token,
             mask_token=mask_token,
-            sep_token=sep_token,
+            additional_special_tokens=additional_special_tokens,
+            lang2id=lang2id,
+            id2lang=id2lang,
             do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
+            tokenizer_file=None,
             **kwargs,
         )

+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use HerbertTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
+
+        # cache of sm.MosesPunctNormalizer instance
+        self.cache_moses_punct_normalizer = dict()
+        # cache of sm.MosesTokenizer instance
+        self.cache_moses_tokenizer = dict()
+        self.lang_with_custom_tokenizer = set(["zh", "th", "ja"])
+        # True for current supported model (v1.2.0), False for XLM-17 & 100
+        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent
+        self.lang2id = lang2id
+        self.id2lang = id2lang
+        if lang2id is not None and id2lang is not None:
+            assert len(lang2id) == len(id2lang)
+
+        self.ja_word_tokenizer = None
+        self.zh_word_tokenizer = None
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            merges = merges_handle.read().split("\n")[:-1]
+        merges = [tuple(merge.split()[:2]) for merge in merges]
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {}
+
         self.bert_pre_tokenizer = BasicTokenizer(
             do_lower_case=False,
             never_split=self.all_special_tokens,
@@ -91,6 +379,112 @@ class HerbertTokenizer(XLMTokenizer):
             strip_accents=False,
         )

+    @property
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.do_lower_case
+    def do_lower_case(self):
+        return self.do_lowercase_and_remove_accent
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_punct_norm
+    def moses_punct_norm(self, text, lang):
+        if lang not in self.cache_moses_punct_normalizer:
+            punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
+            self.cache_moses_punct_normalizer[lang] = punct_normalizer
+        else:
+            punct_normalizer = self.cache_moses_punct_normalizer[lang]
+        return punct_normalizer.normalize(text)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_tokenize
+    def moses_tokenize(self, text, lang):
+        if lang not in self.cache_moses_tokenizer:
+            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
+            self.cache_moses_tokenizer[lang] = moses_tokenizer
+        else:
+            moses_tokenizer = self.cache_moses_tokenizer[lang]
+        return moses_tokenizer.tokenize(text, return_str=False, escape=False)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.moses_pipeline
+    def moses_pipeline(self, text, lang):
+        text = replace_unicode_punct(text)
+        text = self.moses_punct_norm(text, lang)
+        text = remove_non_printing_char(text)
+        return text
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.ja_tokenize
+    def ja_tokenize(self, text):
+        if self.ja_word_tokenizer is None:
+            try:
+                import Mykytea
+
+                self.ja_word_tokenizer = Mykytea.Mykytea(
+                    f"-model {os.path.expanduser('~')}/local/share/kytea/model.bin"
+                )
+            except (AttributeError, ImportError):
+                logger.error(
+                    "Make sure you install KyTea (https://github.com/neubig/kytea) and it's python wrapper"
+                    " (https://github.com/chezou/Mykytea-python) with the following steps"
+                )
+                logger.error("1. git clone git@github.com:neubig/kytea.git && cd kytea")
+                logger.error("2. autoreconf -i")
+                logger.error("3. ./configure --prefix=$HOME/local")
+                logger.error("4. make && make install")
+                logger.error("5. pip install kytea")
+                raise
+        return list(self.ja_word_tokenizer.getWS(text))
+
+    @property
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.vocab_size
+    def vocab_size(self):
+        return len(self.encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.bpe
+    def bpe(self, token):
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        if token in self.cache:
+            return self.cache[token]
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        if word == "\n  </w>":
+            word = "\n</w>"
+        self.cache[token] = word
+        return word
+
     def _tokenize(self, text):

         pre_tokens = self.bert_pre_tokenizer.tokenize(text)
@@ -101,3 +495,155 @@ class HerbertTokenizer(XLMTokenizer):
             split_tokens.extend([t for t in self.bpe(token).split(" ")])

         return split_tokens
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_token_to_id
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer._convert_id_to_token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.decoder.get(index, self.unk_token)
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.convert_tokens_to_string
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        out_string = "".join(tokens).replace("</w>", " ").strip()
+        return out_string
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. An XLM sequence has the following format:
+
+        - single sequence: `<s> X </s>`
+        - pair of sequences: `<s> A </s> B </s>`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        bos = [self.bos_token_id]
+        sep = [self.sep_token_id]
+
+        if token_ids_1 is None:
+            return bos + token_ids_0 + sep
+        return bos + token_ids_0 + sep + token_ids_1 + sep
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.get_special_tokens_mask
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1]
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.create_token_type_ids_from_sequences
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLM sequence
+        pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.save_vocabulary
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+        merge_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+                        " Please check that the tokenizer is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(" ".join(bpe_tokens) + "\n")
+                index += 1
+
+        return vocab_file, merge_file
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__getstate__
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sm"] = None
+        return state
+
+    # Copied from transformers.models.xlm.tokenization_xlm.XLMTokenizer.__setstate__
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        try:
+            import sacremoses
+        except ImportError:
+            raise ImportError(
+                "You need to install sacremoses to use XLMTokenizer. "
+                "See https://pypi.org/project/sacremoses/ for installation."
+            )
+
+        self.sm = sacremoses
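
For reference, a small sketch of the special-token layout implemented by the copied `build_inputs_with_special_tokens`, `get_special_tokens_mask`, and `create_token_type_ids_from_sequences` methods above; the checkpoint name and input strings are illustrative only:

```
from transformers import HerbertTokenizer

tok = HerbertTokenizer.from_pretrained("allegro/herbert-base-cased")  # assumed checkpoint
ids_a = tok.convert_tokens_to_ids(tok.tokenize("Pierwsze zdanie."))
ids_b = tok.convert_tokens_to_ids(tok.tokenize("Drugie zdanie."))

print(tok.build_inputs_with_special_tokens(ids_a))             # <s> A </s>
print(tok.build_inputs_with_special_tokens(ids_a, ids_b))      # <s> A </s> B </s>
print(tok.get_special_tokens_mask(ids_a, ids_b))               # 1 at each special-token position, 0 elsewhere
print(tok.create_token_type_ids_from_sequences(ids_a, ids_b))  # 0s for "<s> A </s>", then 1s for "B </s>"
```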