Unverified Commit b8805084 authored by Mehrad Moradshahi's avatar Mehrad Moradshahi Committed by GitHub
Browse files

tokenization_marian.py: use current_spm for decoding (#10357)



* Fix Marian decoding

Tokenizer's decode and batch_decode now accepts a new argument (use_source_tokenizer) which indicates whether the source spm should be used to decode ids. This is useful for Marian models specificallly when decoding source input ids.

* Adapt docstrings
Co-authored-by: default avatarSylvain Gugger <sylvain.gugger@gmail.com>
parent 8fd7eb34
......@@ -159,7 +159,7 @@ class MarianTokenizer(PreTrainedTokenizer):
return self.encoder.get(token, self.encoder[self.unk_token])
def remove_language_code(self, text: str):
"""Remove language codes like <<fr>> before sentencepiece"""
"""Remove language codes like >>fr<< before sentencepiece"""
match = self.language_code_re.match(text)
code: list = [match.group(0)] if match else []
return code, self.language_code_re.sub("", text)
......@@ -170,11 +170,61 @@ class MarianTokenizer(PreTrainedTokenizer):
return code + pieces
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the encoder."""
"""Converts an index (integer) in a token (str) using the decoder."""
return self.decoder.get(index, self.unk_token)
def batch_decode(self, sequences, **kwargs):
"""
Convert a list of lists of token ids into a list of strings by calling decode.
Args:
sequences (:obj:`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the ``__call__`` method.
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to clean up the tokenization spaces.
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
problems).
kwargs (additional keyword arguments, `optional`):
Will be passed to the underlying model specific decode method.
Returns:
:obj:`List[str]`: The list of decoded sentences.
"""
return super().batch_decode(sequences, **kwargs)
def decode(self, token_ids, **kwargs):
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
token_ids (:obj:`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the ``__call__`` method.
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to remove special tokens in the decoding.
clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to clean up the tokenization spaces.
use_source_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the source tokenizer to decode sequences (only applicable in sequence-to-sequence
problems).
kwargs (additional keyword arguments, `optional`):
Will be passed to the underlying model specific decode method.
Returns:
:obj:`str`: The decoded sentence.
"""
return super().decode(token_ids, **kwargs)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Uses target language sentencepiece model"""
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """
if self._decode_use_source_tokenizer:
return self.spm_source.DecodePieces(tokens)
else:
return self.spm_target.DecodePieces(tokens)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
......
......@@ -486,6 +486,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
token_ids: List[int],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
**kwargs
) -> str:
"""
special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
......
......@@ -122,6 +122,8 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
self.added_tokens_decoder: Dict[int, str] = {}
self.unique_no_split_tokens: List[str] = []
self._decode_use_source_tokenizer = False
@property
def is_fast(self) -> bool:
return False
......@@ -702,7 +704,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
spaces_between_special_tokens: bool = True,
**kwargs
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# To avoid mixing byte-level and unicode for byte-level BPT
......
......@@ -106,6 +106,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
if slow_tokenizer is not None:
kwargs.update(slow_tokenizer.init_kwargs)
self._decode_use_source_tokenizer = False
# We call this after having initialized the backend tokenizer because we update it.
super().__init__(**kwargs)
......@@ -491,6 +493,8 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
clean_up_tokenization_spaces: bool = True,
**kwargs
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
if isinstance(token_ids, int):
token_ids = [token_ids]
text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment