Unverified commit 2ae3be54, authored by Suraj Patil, committed by GitHub

[MBartTokenizer] remove dep on xlm-roberta tokenizer (#15201)

parent 84c60a7b
src/transformers/models/mbart/tokenization_mbart.py

@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from contextlib import contextmanager
-from typing import List, Optional
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple

-from ...tokenization_utils import BatchEncoding
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
 from ...utils import logging
-from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer

 logger = logging.get_logger(__name__)

+SPIECE_UNDERLINE = "▁"
+
 VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
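With the `xlm_roberta` import gone, the slow tokenizer now drives SentencePiece itself. A minimal sketch of that dependency, assuming a local copy of the file named in `VOCAB_FILES_NAMES` (the path is illustrative):

```python
import sentencepiece as spm

# The tokenizer constructs its own processor (sp_model_kwargs are forwarded here).
sp = spm.SentencePieceProcessor()
sp.Load("sentencepiece.bpe.model")  # assumed local path for illustration

# Encode text into subword pieces, each prefixed with SPIECE_UNDERLINE ("▁").
pieces = sp.encode("UN Chief Says There Is No Military Solution", out_type=str)
print(pieces)
```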
@@ -38,41 +42,17 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     "facebook/mbart-large-cc25": 1024,
 }

-FAIRSEQ_LANGUAGE_CODES = [
-    "ar_AR",
-    "cs_CZ",
-    "de_DE",
-    "en_XX",
-    "es_XX",
-    "et_EE",
-    "fi_FI",
-    "fr_XX",
-    "gu_IN",
-    "hi_IN",
-    "it_IT",
-    "ja_XX",
-    "kk_KZ",
-    "ko_KR",
-    "lt_LT",
-    "lv_LV",
-    "my_MM",
-    "ne_NP",
-    "nl_XX",
-    "ro_RO",
-    "ru_RU",
-    "si_LK",
-    "tr_TR",
-    "vi_VN",
-    "zh_CN",
-]
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN"]
+# fmt: on


-class MBartTokenizer(XLMRobertaTokenizer):
+class MBartTokenizer(PreTrainedTokenizer):
     """
     Construct an MBART tokenizer.

-    [`MBartTokenizer`] is a subclass of [`XLMRobertaTokenizer`]. Refer to superclass [`XLMRobertaTokenizer`] for usage
-    examples and documentation concerning the initialization parameters and other methods.
+    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+    [SentencePiece](https://github.com/google/sentencepiece).

     The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
     <tokens> <eos>``` for target language documents.
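For reference, a minimal usage sketch of the scheme described in the docstring above; the en-ro checkpoint and the sentence pair are the ones used in the library documentation, not part of this diff:

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
src_text = "UN Chief Says There Is No Military Solution in Syria"
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"

inputs = tokenizer(src_text, return_tensors="pt")  # input_ids end with </s> followed by the en_XX code
with tokenizer.as_target_tokenizer():
    labels = tokenizer(tgt_text, return_tensors="pt")
inputs["labels"] = labels["input_ids"]
```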
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]

     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []

     def __init__(
-        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        tokenizer_file=None,
+        src_lang=None,
+        tgt_lang=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        additional_special_tokens=None,
+        **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
         super().__init__(
-            *args,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
             tokenizer_file=tokenizer_file,
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 token
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {
             code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
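The effect of the alignment table above, sketched with the same constants; `token_to_id` is a hypothetical helper that mirrors `_convert_token_to_id` further down in this diff, and it assumes a SentencePiece model laid out as in the comment:

```python
import sentencepiece as spm

fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
fairseq_offset = 1  # every regular spm id is shifted up by one


def token_to_id(sp_model: spm.SentencePieceProcessor, token: str) -> int:
    # Pinned control tokens keep their fairseq positions.
    if token in fairseq_tokens_to_ids:
        return fairseq_tokens_to_ids[token]
    spm_id = sp_model.PieceToId(token)
    # spm returns 0 for unknown pieces, which must map back to <unk>.
    return spm_id + fairseq_offset if spm_id else fairseq_tokens_to_ids["<unk>"]


# Usage (assumes `sp` is a loaded SentencePieceProcessor for the mBART vocab):
# token_to_id(sp, ",")   -> 4   (spm id 3 + offset 1)
# token_to_id(sp, "<s>") -> 0   (pinned fairseq id)
# Language codes then occupy the ids right after the spm vocabulary,
# i.e. len(sp) + i + fairseq_offset, as in lang_code_to_id above.
```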
@@ -132,6 +156,22 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self.tgt_lang = tgt_lang
         self.set_src_lang_special_tokens(self._src_lang)

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
     @property
     def vocab_size(self):
         return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token
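The two dunder methods above exist so the tokenizer survives pickling even though the SentencePiece processor is a C++ object. A quick round-trip sketch, assuming the checkpoint is cached or downloadable:

```python
import pickle

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

# __getstate__ drops the C-backed sp_model and stores its serialized proto;
# __setstate__ rebuilds the processor from that proto.
restored = pickle.loads(pickle.dumps(tokenizer))

assert restored.vocab_size == tokenizer.vocab_size
assert restored.convert_tokens_to_ids("▁de") == tokenizer.convert_tokens_to_ids("▁de")
```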
@@ -202,6 +242,31 @@ class MBartTokenizer(XLMRobertaTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
     def _build_translation_inputs(
         self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
     ):
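Because mBART has no segment embeddings, the new method only reports lengths. A small sketch of the observable behaviour (checkpoint name assumed; the example tokens are arbitrary):

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
ids_a = tokenizer.convert_tokens_to_ids(["▁UN", "▁Chief"])
ids_b = tokenizer.convert_tokens_to_ids(["▁in", "▁Syria"])

print(tokenizer.create_token_type_ids_from_sequences(ids_a))         # [0, 0, 0, 0]          (cls + 2 + sep)
print(tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b))  # [0, 0, 0, 0, 0, 0, 0, 0]
```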
@@ -214,6 +279,47 @@ class MBartTokenizer(XLMRobertaTokenizer):
         inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs

+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        if token in self.fairseq_tokens_to_ids:
+            return self.fairseq_tokens_to_ids[token]
+        spm_id = self.sp_model.PieceToId(token)
+
+        # Need to return unknown token if the SP model returned 0
+        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        if index in self.fairseq_ids_to_tokens:
+            return self.fairseq_ids_to_tokens[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
...
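The new `save_vocabulary` simply copies the original SentencePiece file next to the other tokenizer files. A sketch of the observable behaviour, with an illustrative output directory:

```python
import os

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
tokenizer.save_pretrained("./mbart-tokenizer")  # calls save_vocabulary under the hood

print(os.listdir("./mbart-tokenizer"))  # expect sentencepiece.bpe.model among the saved files
```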
src/transformers/models/mbart/tokenization_mbart_fast.py

@@ -13,15 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from contextlib import contextmanager
-from typing import List, Optional
+from shutil import copyfile
+from typing import List, Optional, Tuple

 from tokenizers import processors

 from ...file_utils import is_sentencepiece_available
-from ...tokenization_utils import BatchEncoding
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast

 if is_sentencepiece_available():
@@ -51,36 +53,12 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     "facebook/mbart-large-cc25": 1024,
 }

-FAIRSEQ_LANGUAGE_CODES = [
-    "ar_AR",
-    "cs_CZ",
-    "de_DE",
-    "en_XX",
-    "es_XX",
-    "et_EE",
-    "fi_FI",
-    "fr_XX",
-    "gu_IN",
-    "hi_IN",
-    "it_IT",
-    "ja_XX",
-    "kk_KZ",
-    "ko_KR",
-    "lt_LT",
-    "lv_LV",
-    "my_MM",
-    "ne_NP",
-    "nl_XX",
-    "ro_RO",
-    "ru_RU",
-    "si_LK",
-    "tr_TR",
-    "vi_VN",
-    "zh_CN",
-]
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN"]
+# fmt: on


-class MBartTokenizerFast(XLMRobertaTokenizerFast):
+class MBartTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
     [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
@@ -111,6 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = MBartTokenizer

     prefix_tokens: List[int] = []
@@ -120,20 +99,40 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         self,
         vocab_file=None,
         tokenizer_file=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
         src_lang=None,
         tgt_lang=None,
         additional_special_tokens=None,
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
             **kwargs,
         )

+        self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
         _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()

         if additional_special_tokens is not None:
@@ -190,6 +189,31 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
     def _build_translation_inputs(
         self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
     ):
@@ -253,3 +277,22 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
             pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
             special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
         )
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
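A hedged sketch of the fast tokenizer's saving behaviour: when the tokenizer was loaded together with the original `sentencepiece.bpe.model`, `can_save_slow_tokenizer` is `True` and saving copies that file; when only `tokenizer.json` is available, `save_vocabulary` raises the `ValueError` above instead. Checkpoint and directory names are illustrative:

```python
from transformers import MBartTokenizerFast

tokenizer = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")

if tokenizer.can_save_slow_tokenizer:
    # Writes tokenizer.json and copies sentencepiece.bpe.model alongside it.
    tokenizer.save_pretrained("./mbart-fast-tokenizer")
```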