Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2ae3be54
Unverified
Commit
2ae3be54
authored
Jan 18, 2022
by
Suraj Patil
Committed by
GitHub
Jan 18, 2022
Browse files
[MBartTokenizer] remove dep on xlm-roberta tokenizer (#15201)
parent
84c60a7b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
219 additions
and
70 deletions
+219
-70
src/transformers/models/mbart/tokenization_mbart.py
src/transformers/models/mbart/tokenization_mbart.py
+143
-37
src/transformers/models/mbart/tokenization_mbart_fast.py
src/transformers/models/mbart/tokenization_mbart_fast.py
+76
-33
No files found.
src/transformers/models/mbart/tokenization_mbart.py
View file @
2ae3be54
...
...
@@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
shutil
import
copyfile
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
...tokenization_utils
import
BatchEncoding
import
sentencepiece
as
spm
from
...tokenization_utils
import
AddedToken
,
BatchEncoding
,
PreTrainedTokenizer
from
...utils
import
logging
from
..xlm_roberta.tokenization_xlm_roberta
import
XLMRobertaTokenizer
logger
=
logging
.
get_logger
(
__name__
)
SPIECE_UNDERLINE
=
"▁"
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"sentencepiece.bpe.model"
}
...
...
@@ -38,41 +42,17 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
}
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizer
(
XLMRobertaTokenizer
):
# fmt: off
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
# fmt: on
class
MBartTokenizer
(
PreTrainedTokenizer
):
"""
Construct an MBART tokenizer.
[`MBartTokenizer`] is a subclass of [`XLMRobertaTokenizer`]. Refer to superclass [`XLMRoberta
Tokenizer`]
for usage
examples and documentation concerning the initialization parameters and other methods
.
Adapted from [`RobertaTokenizer`] and [`XLNet
Tokenizer`]
. Based on
[SentencePiece](https://github.com/google/sentencepiece)
.
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
<tokens> <eos>``` for target language documents.
...
...
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
vocab_files_names
=
VOCAB_FILES_NAMES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
model_input_names
=
[
"input_ids"
,
"attention_mask"
]
prefix_tokens
:
List
[
int
]
=
[]
suffix_tokens
:
List
[
int
]
=
[]
def
__init__
(
self
,
*
args
,
tokenizer_file
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
additional_special_tokens
=
None
,
**
kwargs
self
,
vocab_file
,
bos_token
=
"<s>"
,
eos_token
=
"</s>"
,
sep_token
=
"</s>"
,
cls_token
=
"<s>"
,
unk_token
=
"<unk>"
,
pad_token
=
"<pad>"
,
mask_token
=
"<mask>"
,
tokenizer_file
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
sp_model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
additional_special_tokens
=
None
,
**
kwargs
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token
=
AddedToken
(
mask_token
,
lstrip
=
True
,
rstrip
=
False
)
if
isinstance
(
mask_token
,
str
)
else
mask_token
self
.
sp_model_kwargs
=
{}
if
sp_model_kwargs
is
None
else
sp_model_kwargs
super
().
__init__
(
*
args
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
sep_token
=
sep_token
,
cls_token
=
cls_token
,
pad_token
=
pad_token
,
mask_token
=
mask_token
,
tokenizer_file
=
tokenizer_file
,
src_lang
=
src_lang
,
tgt_lang
=
tgt_lang
,
additional_special_tokens
=
additional_special_tokens
,
sp_model_kwargs
=
self
.
sp_model_kwargs
,
**
kwargs
,
)
self
.
sp_model
=
spm
.
SentencePieceProcessor
(
**
self
.
sp_model_kwargs
)
self
.
sp_model
.
Load
(
str
(
vocab_file
))
self
.
vocab_file
=
vocab_file
# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
# Mimic fairseq token-to-id alignment for the first 4 token
self
.
fairseq_tokens_to_ids
=
{
"<s>"
:
0
,
"<pad>"
:
1
,
"</s>"
:
2
,
"<unk>"
:
3
}
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self
.
fairseq_offset
=
1
self
.
sp_model_size
=
len
(
self
.
sp_model
)
self
.
lang_code_to_id
=
{
code
:
self
.
sp_model_size
+
i
+
self
.
fairseq_offset
for
i
,
code
in
enumerate
(
FAIRSEQ_LANGUAGE_CODES
)
...
...
@@ -132,6 +156,22 @@ class MBartTokenizer(XLMRobertaTokenizer):
self
.
tgt_lang
=
tgt_lang
self
.
set_src_lang_special_tokens
(
self
.
_src_lang
)
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
[
"sp_model"
]
=
None
state
[
"sp_model_proto"
]
=
self
.
sp_model
.
serialized_model_proto
()
return
state
def
__setstate__
(
self
,
d
):
self
.
__dict__
=
d
# for backward compatibility
if
not
hasattr
(
self
,
"sp_model_kwargs"
):
self
.
sp_model_kwargs
=
{}
self
.
sp_model
=
spm
.
SentencePieceProcessor
(
**
self
.
sp_model_kwargs
)
self
.
sp_model
.
LoadFromSerializedProto
(
self
.
sp_model_proto
)
@
property
def
vocab_size
(
self
):
return
len
(
self
.
sp_model
)
+
len
(
self
.
lang_code_to_id
)
+
self
.
fairseq_offset
+
1
# Plus 1 for the mask token
...
...
@@ -202,6 +242,31 @@ class MBartTokenizer(XLMRobertaTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
+
sep
+
token_ids_1
+
sep
)
*
[
0
]
def
_build_translation_inputs
(
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
):
...
...
@@ -214,6 +279,47 @@ class MBartTokenizer(XLMRobertaTokenizer):
inputs
[
"forced_bos_token_id"
]
=
tgt_lang_id
return
inputs
def
get_vocab
(
self
):
vocab
=
{
self
.
convert_ids_to_tokens
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)}
vocab
.
update
(
self
.
added_tokens_encoder
)
return
vocab
def
_tokenize
(
self
,
text
:
str
)
->
List
[
str
]:
return
self
.
sp_model
.
encode
(
text
,
out_type
=
str
)
def
_convert_token_to_id
(
self
,
token
):
"""Converts a token (str) in an id using the vocab."""
if
token
in
self
.
fairseq_tokens_to_ids
:
return
self
.
fairseq_tokens_to_ids
[
token
]
spm_id
=
self
.
sp_model
.
PieceToId
(
token
)
# Need to return unknown token if the SP model returned 0
return
spm_id
+
self
.
fairseq_offset
if
spm_id
else
self
.
unk_token_id
def
_convert_id_to_token
(
self
,
index
):
"""Converts an index (integer) in a token (str) using the vocab."""
if
index
in
self
.
fairseq_ids_to_tokens
:
return
self
.
fairseq_ids_to_tokens
[
index
]
return
self
.
sp_model
.
IdToPiece
(
index
-
self
.
fairseq_offset
)
def
convert_tokens_to_string
(
self
,
tokens
):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string
=
""
.
join
(
tokens
).
replace
(
SPIECE_UNDERLINE
,
" "
).
strip
()
return
out_string
def
save_vocabulary
(
self
,
save_directory
:
str
,
filename_prefix
:
Optional
[
str
]
=
None
)
->
Tuple
[
str
]:
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
f
"Vocabulary path (
{
save_directory
}
) should be a directory"
)
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
VOCAB_FILES_NAMES
[
"vocab_file"
]
)
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
def
prepare_seq2seq_batch
(
self
,
src_texts
:
List
[
str
],
...
...
src/transformers/models/mbart/tokenization_mbart_fast.py
View file @
2ae3be54
...
...
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
shutil
import
copyfile
from
typing
import
List
,
Optional
,
Tuple
from
tokenizers
import
processors
from
...file_utils
import
is_sentencepiece_available
from
...tokenization_utils
import
BatchEncoding
from
...tokenization_utils
import
AddedToken
,
BatchEncoding
from
...tokenization_utils_fast
import
PreTrainedTokenizerFast
from
...utils
import
logging
from
..xlm_roberta.tokenization_xlm_roberta_fast
import
XLMRobertaTokenizerFast
if
is_sentencepiece_available
():
...
...
@@ -51,36 +53,12 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
}
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizerFast
(
XLMRobertaTokenizerFast
):
# fmt: off
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
# fmt: on
class
MBartTokenizerFast
(
PreTrainedTokenizerFast
):
"""
Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
...
...
@@ -111,6 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
vocab_files_names
=
VOCAB_FILES_NAMES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
model_input_names
=
[
"input_ids"
,
"attention_mask"
]
slow_tokenizer_class
=
MBartTokenizer
prefix_tokens
:
List
[
int
]
=
[]
...
...
@@ -120,20 +99,40 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
self
,
vocab_file
=
None
,
tokenizer_file
=
None
,
bos_token
=
"<s>"
,
eos_token
=
"</s>"
,
sep_token
=
"</s>"
,
cls_token
=
"<s>"
,
unk_token
=
"<unk>"
,
pad_token
=
"<pad>"
,
mask_token
=
"<mask>"
,
src_lang
=
None
,
tgt_lang
=
None
,
additional_special_tokens
=
None
,
**
kwargs
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token
=
AddedToken
(
mask_token
,
lstrip
=
True
,
rstrip
=
False
)
if
isinstance
(
mask_token
,
str
)
else
mask_token
super
().
__init__
(
vocab_file
=
vocab_file
,
tokenizer_file
=
tokenizer_file
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
sep_token
=
sep_token
,
cls_token
=
cls_token
,
unk_token
=
unk_token
,
pad_token
=
pad_token
,
mask_token
=
mask_token
,
src_lang
=
src_lang
,
tgt_lang
=
tgt_lang
,
additional_special_tokens
=
additional_special_tokens
,
**
kwargs
,
)
self
.
vocab_file
=
vocab_file
self
.
can_save_slow_tokenizer
=
False
if
not
self
.
vocab_file
else
True
_additional_special_tokens
=
FAIRSEQ_LANGUAGE_CODES
.
copy
()
if
additional_special_tokens
is
not
None
:
...
...
@@ -190,6 +189,31 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
# We don't expect to process pairs, but leave the pair logic for API consistency
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
+
sep
+
token_ids_1
+
sep
)
*
[
0
]
def
_build_translation_inputs
(
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
):
...
...
@@ -253,3 +277,22 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
pair
=
prefix_tokens_str
+
[
"$A"
,
"$B"
]
+
suffix_tokens_str
,
special_tokens
=
list
(
zip
(
prefix_tokens_str
+
suffix_tokens_str
,
self
.
prefix_tokens
+
self
.
suffix_tokens
)),
)
def
save_vocabulary
(
self
,
save_directory
:
str
,
filename_prefix
:
Optional
[
str
]
=
None
)
->
Tuple
[
str
]:
if
not
self
.
can_save_slow_tokenizer
:
raise
ValueError
(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
f
"Vocabulary path (
{
save_directory
}
) should be a directory."
)
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
VOCAB_FILES_NAMES
[
"vocab_file"
]
)
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment