Unverified commit be1520d3, authored by Sam Shleifer, committed by GitHub

rename prepare_translation_batch -> prepare_seq2seq_batch (#6103)

parent 66fa8cea
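In short: every tokenizer that exposed `prepare_translation_batch` now exposes the same behaviour under the more general name `prepare_seq2seq_batch`, since the method also serves summarization, not just translation. A minimal before/after sketch of the renamed API, using the Marian checkpoint that appears in the docs and tests below (illustrative, not part of the diff):

    from transformers import MarianMTModel, MarianTokenizer

    model_name = "Helsinki-NLP/opus-mt-en-de"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    src_text = ["I am a small frog."]
    # Before this commit:
    #   batch = tokenizer.prepare_translation_batch(src_text)
    # After this commit (same behaviour, new name):
    batch = tokenizer.prepare_seq2seq_batch(src_text)  # keys: input_ids, attention_mask
    translated = model.generate(**batch)
    print(tokenizer.batch_decode(translated, skip_special_tokens=True))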
@@ -53,7 +53,7 @@ MBartTokenizer
 ~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: transformers.MBartTokenizer
-    :members: build_inputs_with_special_tokens, prepare_translation_batch
+    :members: build_inputs_with_special_tokens, prepare_seq2seq_batch
...
@@ -48,7 +48,7 @@ Example of translating english to many romance languages, using language codes:
     tokenizer = MarianTokenizer.from_pretrained(model_name)
     print(tokenizer.supported_language_codes)
     model = MarianMTModel.from_pretrained(model_name)
-    translated = model.generate(**tokenizer.prepare_translation_batch(src_text))
+    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
     tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
     # ["c'est une phrase en anglais que nous voulons traduire en français",
     # 'Isto deve ir para o português.',
@@ -86,6 +86,14 @@ Code to see available pretrained models:
     suffix = [x.split('/')[1] for x in model_ids]
     multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]
+MarianMTModel
+~~~~~~~~~~~~~
+Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
+Model API is identical to BartForConditionalGeneration.
+Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__
+This class inherits nearly all functionality from ``BartForConditionalGeneration``, see that page for method signatures.
+.. autoclass:: transformers.MarianMTModel
+    :members:
 MarianConfig
 ~~~~~~~~~~~~~~~~~~~
 .. autoclass:: transformers.MarianConfig
@@ -96,16 +104,8 @@ MarianTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: transformers.MarianTokenizer
-    :members: prepare_translation_batch
+    :members: prepare_seq2seq_batch
-MarianMTModel
-~~~~~~~~~~~~~
-Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.
-Model API is identical to BartForConditionalGeneration.
-Available models are listed at `Model List <https://huggingface.co/models?search=Helsinki-NLP>`__
-This class inherits all functionality from ``BartForConditionalGeneration``, see that page for method signatures.
-.. autoclass:: transformers.MarianMTModel
-    :members:
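The model-listing snippet in the hunk above enters mid-stream, so `model_ids` and `org` are defined off-screen in the docs. For context only, a self-contained version of that lookup might look like the following sketch; the `HfApi().model_list()` call and its `modelId` attribute are assumptions about the hub API of that era, not quoted from this diff:

    from transformers.hf_api import HfApi  # assumed hub-listing helper

    org = "Helsinki-NLP"
    # All hub checkpoints under the Helsinki-NLP organisation (assumed endpoint).
    model_ids = [x.modelId for x in HfApi().model_list() if x.modelId.startswith(org)]
    suffix = [x.split('/')[1] for x in model_ids]
    # Multilingual checkpoints use capitalised group codes, e.g. en-ROMANCE.
    multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]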
@@ -63,7 +63,7 @@ Summarization Tips:
 (It rarely makes sense to start from `bart-large` unless you are a researching finetuning methods).
 **Update 2018-07-18**
-Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_translation_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
+Datasets: `Seq2SeqDataset` should be used for all tokenizers without a `prepare_seq2seq_batch` method. For those who do (like Marian, MBart), `TranslationDataset` should be used.**
 A new dataset is needed to support multilingual tasks.
...
@@ -145,7 +145,7 @@ class Seq2SeqDataset(Dataset):
 class TranslationDataset(Seq2SeqDataset):
-    """A dataset that calls prepare_translation_batch."""
+    """A dataset that calls prepare_seq2seq_batch."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -167,7 +167,7 @@ class TranslationDataset(Seq2SeqDataset):
         }
     def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
-        batch_encoding = self.tokenizer.prepare_translation_batch(
+        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
             [x["src_texts"] for x in batch],
             src_lang=self.src_lang,
             tgt_texts=[x["tgt_texts"] for x in batch],
...
@@ -40,7 +40,7 @@ class MarianMTModel(BartForConditionalGeneration):
         >>> model = MarianMTModel.from_pretrained(mname)
         >>> tok = MarianTokenizer.from_pretrained(mname)
-        >>> batch = tok.prepare_translation_batch(src_texts=[sample_text])  # don't need tgt_text for inference
+        >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text])  # don't need tgt_text for inference
         >>> gen = model.generate(**batch)  # for forward pass: model(**batch)
         >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True)  # returns "Where is the the bus stop ?"
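The docstring example above also enters mid-hunk: `mname` and `sample_text` are defined off-screen. A plausible self-contained version, for context; the fr-en checkpoint and the French sample are assumptions inferred from the expected English output, not quoted from the diff:

    from typing import List
    from transformers import MarianMTModel, MarianTokenizer

    mname = "Helsinki-NLP/opus-mt-fr-en"  # assumed source/target pair
    sample_text = "où est l'arrêt de bus ?"  # assumed off-screen sample

    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)
    batch = tok.prepare_seq2seq_batch(src_texts=[sample_text])  # no tgt_text needed for inference
    gen = model.generate(**batch)
    words: List[str] = tok.batch_decode(gen, skip_special_tokens=True)  # "Where is the the bus stop ?"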
...
@@ -16,8 +16,10 @@
 import logging
 from typing import List, Optional
+from .file_utils import add_start_docstrings_to_callable
 from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
 from .tokenization_utils import BatchEncoding
+from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
 from .tokenization_xlm_roberta import XLMRobertaTokenizer
@@ -89,7 +91,7 @@ FAIRSEQ_LANGUAGE_CODES = [
 class MBartTokenizer(XLMRobertaTokenizer):
     """
-    This inherits from XLMRobertaTokenizer. ``prepare_translation_batch`` should be used to encode inputs.
+    This inherits from XLMRobertaTokenizer. ``prepare_seq2seq_batch`` should be used to encode inputs.
     Other tokenizer methods like ``encode`` do not work properly.
     The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and
     ``<language code> <tokens> <eos>``` for target language documents.
@@ -100,7 +102,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
     >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
     >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
     >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    >>> batch: dict = tokenizer.prepare_translation_batch(
+    >>> batch: dict = tokenizer.prepare_seq2seq_batch(
     ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
     ... )
@@ -187,7 +189,8 @@ class MBartTokenizer(XLMRobertaTokenizer):
             return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
         return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
-    def prepare_translation_batch(
+    @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+    def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         src_lang: str = "en_XX",
@@ -195,22 +198,73 @@ class MBartTokenizer(XLMRobertaTokenizer):
         tgt_lang: str = "ro_RO",
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
+        truncation: bool = True,
         padding: str = "longest",
         return_tensors: str = "pt",
         **kwargs,
     ) -> BatchEncoding:
         """Prepare a batch that can be passed directly to an instance of MBartModel.
         Arguments:
-            src_texts: list of src language texts
-            src_lang: default en_XX (english), the language we are translating from
-            tgt_texts: list of tgt language texts
-            tgt_lang: default ro_RO (romanian), the language we are translating to
-            max_length: (default=None, which defers to the config value of 1024 for facebook/mbart-large*
-            padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest.
-            **kwargs: passed to self.__call__
-        Returns:
-            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
+            src_texts: (:obj:`list`):
+                list of documents to summarize or source language texts
+            src_lang: (:obj:`str`, `optional`, default='en_XX'):
+                default en_XX (english), the language we are translating from
+            tgt_texts: (:obj:`list`, `optional`):
+                list of tgt language texts or summaries.
+            tgt_lang: (:obj:`str`, `optional`, default='ro_RO'):
+                default ro_RO (romanian), the language we are translating to
+            max_length (:obj:`int`, `optional`):
+                Controls the maximum length for encoder inputs (documents to summarize or source language texts)
+                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
+                length is required by one of the truncation/padding parameters. If the model has no specific maximum
+                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
+            max_target_length (:obj:`int`, `optional`):
+                Controls the maximum length of decoder inputs (target language texts or summaries)
+                If left unset or set to :obj:`None`, this will use the max_length value.
+            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
+                Activates and controls padding. Accepts the following values:
+                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+                  single sequence if provided).
+                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided.
+                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+                  different lengths).
+            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
+            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
+                Activates and controls truncation. Accepts the following values:
+                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
+                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
+                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
+                  if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
+                  the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
+                  sequence lengths greater than the model maximum admissible input size).
+        Return:
+            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
+            - **input_ids** -- List of token ids to be fed to the encoder.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
+            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
+              This does not include causal mask, which is built by the model.
+            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
+            will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
         """
         if max_length is None:
             max_length = self.max_len
@@ -221,7 +275,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
             return_tensors=return_tensors,
             max_length=max_length,
             padding=padding,
-            truncation=True,
+            truncation=truncation,
             **kwargs,
         )
         if tgt_texts is None:
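Worth noting: the hunk above also threads a new `truncation` argument through to `self.__call__` instead of hard-coding `truncation=True`. A minimal sketch of what that enables, reusing the `facebook/mbart-large-en-ro` checkpoint from the class docstring (illustrative, not part of the diff):

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
    src = ["UN Chief Says There Is No Military Solution in Syria"]

    # Default behaviour is unchanged: truncate to max_length.
    batch = tokenizer.prepare_seq2seq_batch(src, src_lang="en_XX", max_length=8)

    # New: truncation can now be disabled or set to an explicit strategy.
    batch_untruncated = tokenizer.prepare_seq2seq_batch(src, src_lang="en_XX", truncation=False)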
...
@@ -7,7 +7,9 @@ from typing import Dict, List, Optional, Tuple, Union
 import sentencepiece
+from .file_utils import add_start_docstrings_to_callable
 from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
+from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
 vocab_files_names = {
@@ -21,7 +23,8 @@ vocab_files_names = {
 class MarianTokenizer(PreTrainedTokenizer):
     """Sentencepiece tokenizer for marian. Source and target languages have different SPM models.
-    The logic is use the relevant source_spm or target_spm to encode txt as pieces, then look up each piece in a vocab dictionary.
+    The logic is use the relevant source_spm or target_spm to encode txt as pieces, then look up each piece in a
+    vocab dictionary.
     Examples::
@@ -29,7 +32,7 @@ class MarianTokenizer(PreTrainedTokenizer):
         >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
         >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
         >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
-        >>> batch_enc: BatchEncoding = tok.prepare_translation_batch(src_texts, tgt_texts=tgt_texts)
+        >>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts)
         >>> # keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask].
         >>> # model(**batch) should work
     """
@@ -122,30 +125,20 @@ class MarianTokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]
-    def prepare_translation_batch(
+    @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
+    def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        pad_to_max_length: bool = True,
         return_tensors: str = "pt",
-        truncation_strategy="only_first",
+        truncation=True,
         padding="longest",
         **unused,
     ) -> BatchEncoding:
         """Prepare model inputs for translation. For best performance, translate one sentence at a time.
-        Arguments:
-            src_texts: list of src language texts
-            tgt_texts: list of tgt language texts
-            max_length: (None) defer to config (1024 for mbart-large-en-ro)
-            pad_to_max_length: (bool)
-            return_tensors: (str) default "pt" returns pytorch tensors, pass None to return lists.
-        Returns:
-            BatchEncoding: with keys [input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]
-            all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
-            If no tgt_text is specified, the only keys will be input_ids and attention_mask.
         """
         if "" in src_texts:
             raise ValueError(f"found empty string in src_texts: {src_texts}")
@@ -155,14 +148,15 @@ class MarianTokenizer(PreTrainedTokenizer):
             add_special_tokens=True,
             return_tensors=return_tensors,
             max_length=max_length,
-            pad_to_max_length=pad_to_max_length,
-            truncation_strategy=truncation_strategy,
+            truncation=truncation,
             padding=padding,
         )
         model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
         if tgt_texts is None:
             return model_inputs
+        if max_target_length is not None:
+            tokenizer_kwargs["max_length"] = max_target_length
         if max_target_length is not None:
             tokenizer_kwargs["max_length"] = max_target_length
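The tail of this hunk shows how `max_target_length` works: after encoding `src_texts`, the same `tokenizer_kwargs` are reused for the targets with `max_length` overridden. A short illustrative sketch of the resulting behaviour, mirroring the assertions in the tests further down (checkpoint name taken from the class docstring above):

    from transformers import MarianTokenizer

    tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    batch = tok.prepare_seq2seq_batch(
        ["I am a small frog."], tgt_texts=["Ich bin ein kleiner Frosch."],
        max_length=3, max_target_length=10,
    )
    assert batch.input_ids.shape[1] == 3           # encoder side truncated to max_length
    assert batch.decoder_input_ids.shape[1] <= 10  # decoder side capped by max_target_length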
...
@@ -16,7 +16,8 @@ from typing import Dict, List, Optional
 from transformers.tokenization_reformer import ReformerTokenizer
-from .tokenization_utils_base import BatchEncoding
+from .file_utils import add_start_docstrings_to_callable
+from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 class PegasusTokenizer(ReformerTokenizer):
@@ -103,6 +104,7 @@ class PegasusTokenizer(ReformerTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + token_ids_1 + [self.eos_token_id]
+    @add_start_docstrings_to_callable(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
@@ -116,62 +118,6 @@ class PegasusTokenizer(ReformerTokenizer):
         """
         Prepare model inputs for summarization or translation.
-        Arguments:
-            src_texts: (:obj:`list`):
-                list of documents to summarize or source language texts
-            tgt_texts: (:obj:`list`, `optional`):
-                list of tgt language texts or summaries.
-            max_length (:obj:`int`, `optional`):
-                Controls the maximum length for encoder inputs (documents to summarize or source language texts)
-                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
-                length is required by one of the truncation/padding parameters. If the model has no specific maximum
-                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
-            max_target_length (:obj:`int`, `optional`):
-                Controls the maximum length of decoder inputs (target language texts or summaries)
-                If left unset or set to :obj:`None`, this will use the max_length value.
-            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
-                Activates and controls padding. Accepts the following values:
-                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-                  single sequence if provided).
-                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
-                  maximum acceptable input length for the model if that argument is not provided.
-                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
-                  different lengths).
-            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
-                If set, will return tensors instead of list of python integers. Acceptable values are:
-                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
-                Activates and controls truncation. Accepts the following values:
-                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
-                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
-                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
-                  if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
-                  the maximum acceptable input length for the model if that argument is not provided. This will only
-                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will only
-                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
-                  sequence lengths greater than the model maximum admissible input size).
-        Return:
-            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
-            - **input_ids** -- List of token ids to be fed to the encoder.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
-            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
-            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
-              This does not include causal mask, which is built by the model.
-            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
-            will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
         """
         if "" in src_texts:
             raise ValueError(f"found empty string in src_texts: {src_texts}")
...
@@ -1249,6 +1249,67 @@ INIT_TOKENIZER_DOCSTRING = r"""
 """
+PREPARE_SEQ2SEQ_BATCH_DOCSTRING = """
+    Arguments:
+        src_texts: (:obj:`list`):
+            list of documents to summarize or source language texts
+        tgt_texts: (:obj:`list`, `optional`):
+            list of tgt language texts or summaries.
+        max_length (:obj:`int`, `optional`):
+            Controls the maximum length for encoder inputs (documents to summarize or source language texts)
+            If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
+            length is required by one of the truncation/padding parameters. If the model has no specific maximum
+            input length (like XLNet) truncation/padding to a maximum length will be deactivated.
+        max_target_length (:obj:`int`, `optional`):
+            Controls the maximum length of decoder inputs (target language texts or summaries)
+            If left unset or set to :obj:`None`, this will use the max_length value.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
+            Activates and controls padding. Accepts the following values:
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
+              single sequence if provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+              different lengths).
+        return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
+            If set, will return tensors instead of list of python integers. Acceptable values are:
+            * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
+            * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
+            * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
+        truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
+            Activates and controls truncation. Accepts the following values:
+            * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
+              :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
+              provided. This will truncate token by token, removing a token from the longest sequence in the pair
+              if a pair of sequences (or a batch of pairs) is provided.
+            * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
+              the maximum acceptable input length for the model if that argument is not provided. This will only
+              truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+            * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
+              to the maximum acceptable input length for the model if that argument is not provided. This will only
+              truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+            * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
+              sequence lengths greater than the model maximum admissible input size).
+    Return:
+        :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
+        - **input_ids** -- List of token ids to be fed to the encoder.
+        - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+        - **decoder_input_ids** -- List of token ids to be fed to the decoder.
+        - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
+          This does not include causal mask, which is built by the model.
+        The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
+        will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
+"""
 @add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
 class PreTrainedTokenizerBase(SpecialTokensMixin):
     """
...
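The `PREPARE_SEQ2SEQ_BATCH_DOCSTRING` block above factors one canonical Arguments/Return section into a module-level constant, which each tokenizer then stamps onto its method via `add_start_docstrings_to_callable`. As a rough illustration of how such a decorator can work (a simplified sketch, not the actual `file_utils` implementation):

    def add_start_docstrings_to_callable(*docstr):
        """Return a decorator that prepends the given docstring fragments to fn.__doc__."""
        def docstring_decorator(fn):
            # Shared text first, then whatever one-line summary the method itself keeps.
            fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
            return fn
        return docstring_decorator

Usage then mirrors the diff: every `prepare_seq2seq_batch` shares the same parameter documentation, while each implementation retains only its own short local summary.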
@@ -97,7 +97,7 @@ class MarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)
     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
+        model_inputs = self.tokenizer.prepare_seq2seq_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
             torch_device
         )
         self.assertEqual(self.model.device, model_inputs.input_ids.device)
@@ -114,7 +114,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
         src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
         expected_ids = [38, 121, 14, 697, 38848, 0]
-        model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
+        model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device)
         self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
         desired_keys = {
@@ -131,12 +131,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
     def test_unk_support(self):
         t = self.tokenizer
-        ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
+        ids = t.prepare_seq2seq_batch(["||"]).to(torch_device).input_ids[0].tolist()
         expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
         self.assertEqual(expected, ids)
     def test_pad_not_split(self):
-        input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
+        input_ids_w_pad = self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
         expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0]  # pad
         self.assertListEqual(expected_w_pad, input_ids_w_pad)
@@ -229,7 +229,7 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
         normalized = self.tokenizer.normalize("")
         self.assertIsInstance(normalized, str)
         with self.assertRaises(ValueError):
-            self.tokenizer.prepare_translation_batch([""])
+            self.tokenizer.prepare_seq2seq_batch([""])
     def test_pipeline(self):
         device = 0 if torch_device == "cuda" else -1
...
@@ -82,7 +82,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
     @slow
     def test_enro_generate(self):
-        batch: BatchEncoding = self.tokenizer.prepare_translation_batch(self.src_text).to(torch_device)
+        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
         translated_tokens = self.model.generate(**batch)
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
         self.assertEqual(self.tgt_text[0], decoded[0])
@@ -134,7 +134,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
     @unittest.skip("This test is broken, still generates english")
     def test_cc25_generate(self):
-        inputs = self.tokenizer.prepare_translation_batch([self.src_text[0]]).to(torch_device)
+        inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]]).to(torch_device)
         translated_tokens = self.model.generate(
             input_ids=inputs["input_ids"].to(torch_device),
             decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
@@ -144,7 +144,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
     @slow
     def test_fill_mask(self):
-        inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
+        inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"]).to(torch_device)
         outputs = self.model.generate(
             inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
         )
...
@@ -1522,3 +1522,37 @@ class TokenizerTesterMixin:
             if batch_encoded_sequence_fast is None:
                 raise ValueError("Cannot convert list to numpy tensor on batch_encode_plus() (fast)")
+    @require_torch
+    def test_prepare_seq2seq_batch(self):
+        tokenizer = self.get_tokenizer()
+        if not hasattr(tokenizer, "prepare_seq2seq_batch"):
+            return
+        # Longer text that will definitely require truncation.
+        src_text = [
+            " UN Chief Says There Is No Military Solution in Syria",
+            " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons will only worsen the violence and misery for millions of people.",
+        ]
+        tgt_text = [
+            "Şeful ONU declară că nu există o soluţie militară în Siria",
+            "Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei "
+            'pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu '
+            "vor face decât să înrăutăţească violenţele şi mizeria pentru milioane de oameni.",
+        ]
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_texts=src_text, tgt_texts=tgt_text, max_length=3, max_target_length=10, return_tensors="pt"
+        )
+        self.assertEqual(batch.input_ids.shape[1], 3)
+        self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+        # max_target_length will default to max_length if not specified
+        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, max_length=3)
+        self.assertEqual(batch.input_ids.shape[1], 3)
+        self.assertEqual(batch.decoder_input_ids.shape[1], 3)
+        batch_encoder_only = tokenizer.prepare_seq2seq_batch(
+            src_texts=src_text, max_length=3, max_target_length=10, return_tensors="pt"
+        )
+        self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
+        self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
+        self.assertNotIn("decoder_input_ids", batch_encoder_only)
@@ -64,7 +64,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_tokenizer_equivalence_en_de(self):
         en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
-        batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
+        batch = en_de_tokenizer.prepare_seq2seq_batch(["I am a small frog"], return_tensors=None)
         self.assertIsInstance(batch, BatchEncoding)
         expected = [38, 121, 14, 697, 38848, 0]
         self.assertListEqual(expected, batch.input_ids[0])
@@ -78,16 +78,12 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_outputs_not_longer_than_maxlen(self):
         tok = self.get_tokenizer()
-        batch = tok.prepare_translation_batch(
-            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
-        )
+        batch = tok.prepare_seq2seq_batch(["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK)
         self.assertIsInstance(batch, BatchEncoding)
         self.assertEqual(batch.input_ids.shape, (2, 512))
     def test_outputs_can_be_shorter(self):
         tok = self.get_tokenizer()
-        batch_smaller = tok.prepare_translation_batch(
-            ["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK
-        )
+        batch_smaller = tok.prepare_seq2seq_batch(["I am a tiny frog", "I am a small frog"], return_tensors=FRAMEWORK)
         self.assertIsInstance(batch_smaller, BatchEncoding)
         self.assertEqual(batch_smaller.input_ids.shape, (2, 10))
@@ -123,8 +123,8 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["en_EN"], 250004)
         self.assertEqual(self.tokenizer.fairseq_tokens_to_ids["ro_RO"], 250020)
-    def test_enro_tokenizer_prepare_translation_batch(self):
-        batch = self.tokenizer.prepare_translation_batch(
+    def test_enro_tokenizer_prepare_seq2seq_batch(self):
+        batch = self.tokenizer.prepare_seq2seq_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
         )
         self.assertIsInstance(batch, BatchEncoding)
@@ -140,13 +140,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     def test_max_target_length(self):
-        batch = self.tokenizer.prepare_translation_batch(
+        batch = self.tokenizer.prepare_seq2seq_batch(
             self.src_text, tgt_texts=self.tgt_text, max_length=3, max_target_length=10
         )
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
         # max_target_length will default to max_length if not specified
-        batch = self.tokenizer.prepare_translation_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
+        batch = self.tokenizer.prepare_seq2seq_batch(self.src_text, tgt_texts=self.tgt_text, max_length=3)
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 3)
@@ -166,7 +166,7 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         src_text = ["this is gunna be a long sentence " * 20]
         assert isinstance(src_text[0], str)
         desired_max_length = 10
-        ids = self.tokenizer.prepare_translation_batch(
+        ids = self.tokenizer.prepare_seq2seq_batch(
             src_text, return_tensors=None, max_length=desired_max_length
         ).input_ids[0]
         self.assertEqual(ids[-2], 2)
...