Unverified Commit 9e147d31 authored by Sylvain Gugger, committed by GitHub

Deprecate prepare_seq2seq_batch (#10287)



* Deprecate prepare_seq2seq_batch

* Fix last tests

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* More review comments
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent e73a3e18
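
After this change, source texts are encoded with the tokenizer's regular ``__call__`` and target texts inside the ``as_target_tokenizer`` context manager. The following minimal sketch gathers that replacement pattern in one place; it is an illustration, not part of the diff, and the ``Helsinki-NLP/opus-mt-en-de`` checkpoint and variable names are only assumptions.

.. code-block:: python

    from transformers import MarianMTModel, MarianTokenizer

    model_name = "Helsinki-NLP/opus-mt-en-de"  # illustrative; any seq2seq checkpoint works the same way
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
    tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]

    # Deprecated (still works for now, but raises a FutureWarning):
    # batch = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")

    # Replacement: encode sources with __call__, targets inside as_target_tokenizer().
    inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
    inputs["labels"] = labels["input_ids"]

    outputs = model(**inputs)  # forward pass with labels populates outputs.loss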
@@ -56,7 +56,7 @@ FSMTTokenizer

 .. autoclass:: transformers.FSMTTokenizer
     :members: build_inputs_with_special_tokens, get_special_tokens_mask,
-        create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
+        create_token_type_ids_from_sequences, save_vocabulary

 FSMTModel
...
@@ -76,27 +76,29 @@ require 3 character language codes:

 .. code-block:: python

-    from transformers import MarianMTModel, MarianTokenizer
-    src_text = [
-        '>>fra<< this is a sentence in english that we want to translate to french',
-        '>>por<< This should go to portuguese',
-        '>>esp<< And this to Spanish'
-    ]
-    model_name = 'Helsinki-NLP/opus-mt-en-roa'
-    tokenizer = MarianTokenizer.from_pretrained(model_name)
-    print(tokenizer.supported_language_codes)
-    model = MarianMTModel.from_pretrained(model_name)
-    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
-    tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
-    # ["c'est une phrase en anglais que nous voulons traduire en français",
-    #  'Isto deve ir para o português.',
-    #  'Y esto al español']
+    >>> from transformers import MarianMTModel, MarianTokenizer
+    >>> src_text = [
+    ...     '>>fra<< this is a sentence in english that we want to translate to french',
+    ...     '>>por<< This should go to portuguese',
+    ...     '>>esp<< And this to Spanish'
+    >>> ]
+    >>> model_name = 'Helsinki-NLP/opus-mt-en-roa'
+    >>> tokenizer = MarianTokenizer.from_pretrained(model_name)
+    >>> print(tokenizer.supported_language_codes)
+    ['>>zlm_Latn<<', '>>mfe<<', '>>hat<<', '>>pap<<', '>>ast<<', '>>cat<<', '>>ind<<', '>>glg<<', '>>wln<<', '>>spa<<', '>>fra<<', '>>ron<<', '>>por<<', '>>ita<<', '>>oci<<', '>>arg<<', '>>min<<']
+    >>> model = MarianMTModel.from_pretrained(model_name)
+    >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
+    >>> [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
+    ["c'est une phrase en anglais que nous voulons traduire en français",
+     'Isto deve ir para o português.',
+     'Y esto al español']

-Code to see available pretrained models:
+Here is the code to see all available pretrained models on the hub:

 .. code-block:: python
@@ -147,21 +149,22 @@ Example of translating english to many romance languages, using old-style 2 char

 .. code-block:: python

-    from transformers import MarianMTModel, MarianTokenizer
-    src_text = [
-        '>>fr<< this is a sentence in english that we want to translate to french',
-        '>>pt<< This should go to portuguese',
-        '>>es<< And this to Spanish'
-    ]
-    model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
-    tokenizer = MarianTokenizer.from_pretrained(model_name)
-    print(tokenizer.supported_language_codes)
-    model = MarianMTModel.from_pretrained(model_name)
-    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
-    tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
-    # ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
+    >>> from transformers import MarianMTModel, MarianTokenizer
+    >>> src_text = [
+    ...     '>>fr<< this is a sentence in english that we want to translate to french',
+    ...     '>>pt<< This should go to portuguese',
+    ...     '>>es<< And this to Spanish'
+    >>> ]
+    >>> model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
+    >>> tokenizer = MarianTokenizer.from_pretrained(model_name)
+    >>> model = MarianMTModel.from_pretrained(model_name)
+    >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
+    >>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
+    ["c'est une phrase en anglais que nous voulons traduire en français",
+     'Isto deve ir para o português.',
+     'Y esto al español']
@@ -176,7 +179,7 @@ MarianTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.MarianTokenizer
-    :members: prepare_seq2seq_batch
+    :members: as_target_tokenizer

 MarianModel
...
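
The marian.rst hunk above points at a snippet for listing all available pretrained Marian checkpoints on the hub, but that snippet itself falls outside the diff context. A rough, hedged equivalent using the ``huggingface_hub`` client (an assumption, not the code in the docs file) could look like:

.. code-block:: python

    from huggingface_hub import list_models

    # Marian checkpoints are published under the Helsinki-NLP organization as opus-mt-{src}-{tgt}.
    marian_models = [m.id for m in list_models(author="Helsinki-NLP", search="opus-mt")]
    print(len(marian_models), marian_models[:5])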
@@ -34,22 +34,31 @@ The Authors' code can be found `here <https://github.com/pytorch/fairseq/tree/ma

 Training of MBart
 _______________________________________________________________________________________________________________________

-MBart is a multilingual encoder-decoder (seq-to-seq) model primarily intended for translation tasks. As the model is
-multilingual it expects the sequences in a different format. A special language id token is added in both the source
-and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The target
-text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.
+MBart is a multilingual encoder-decoder (sequence-to-sequence) model primarily intended for translation tasks. As the
+model is multilingual it expects the sequences in a different format. A special language id token is added in both the
+source and target text. The source text format is :obj:`X [eos, src_lang_code]` where :obj:`X` is the source text. The
+target text format is :obj:`[tgt_lang_code] X [eos]`. :obj:`bos` is never used.

-The :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch` handles this automatically and should be used to encode
-the sequences for sequence-to-sequence fine-tuning.
+The regular :meth:`~transformers.MBartTokenizer.__call__` will encode the source text format, and it should be wrapped
+inside the context manager :meth:`~transformers.MBartTokenizer.as_target_tokenizer` to encode the target text format.

 - Supervised training

 .. code-block::

-    example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
-    expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
-    model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
+    >>> from transformers import MBartForConditionalGeneration, MBartTokenizer

+    >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
+    >>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
+    >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"

+    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+    >>> with tokenizer.as_target_tokenizer():
+    ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")

+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
+    >>> # forward pass
+    >>> model(**inputs, labels=labels["input_ids"])

 - Generation
@@ -58,14 +67,14 @@ the sequences for sequence-to-sequence fine-tuning.

 .. code-block::

-    from transformers import MBartForConditionalGeneration, MBartTokenizer
-    model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
-    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
-    article = "UN Chief Says There Is No Military Solution in Syria"
-    batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
-    translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
-    translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
-    assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
+    >>> from transformers import MBartForConditionalGeneration, MBartTokenizer

+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
+    >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
+    >>> article = "UN Chief Says There Is No Military Solution in Syria"
+    >>> inputs = tokenizer(article, return_tensors="pt")
+    >>> translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
+    >>> tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
+    "Şeful ONU declară că nu există o soluţie militară în Siria"

 Overview of MBart-50
@@ -160,7 +169,7 @@ MBartTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.MBartTokenizer
-    :members: build_inputs_with_special_tokens, prepare_seq2seq_batch
+    :members: as_target_tokenizer, build_inputs_with_special_tokens

 MBartTokenizerFast
...
@@ -78,20 +78,20 @@ Usage Example

 .. code-block:: python

-    from transformers import PegasusForConditionalGeneration, PegasusTokenizer
-    import torch
-    src_text = [
-        """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
-    ]
-    model_name = 'google/pegasus-xsum'
-    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    tokenizer = PegasusTokenizer.from_pretrained(model_name)
-    model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
-    batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
-    translated = model.generate(**batch)
-    tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
-    assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
+    >>> from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+    >>> import torch
+    >>> src_text = [
+    ...     """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
+    >>> ]

+    >>> model_name = 'google/pegasus-xsum'
+    >>> device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    >>> tokenizer = PegasusTokenizer.from_pretrained(model_name)
+    >>> model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)
+    >>> batch = tokenizer(src_text, truncation=True, padding='longest', return_tensors="pt").to(device)
+    >>> translated = model.generate(**batch)
+    >>> tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
+    >>> assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
@@ -107,7 +107,7 @@ PegasusTokenizer

 warning: ``add_tokens`` does not work at the moment.

 .. autoclass:: transformers.PegasusTokenizer
-    :members: __call__, prepare_seq2seq_batch
+    :members:

 PegasusTokenizerFast
...
@@ -56,7 +56,7 @@ RagTokenizer
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.RagTokenizer
-    :members: prepare_seq2seq_batch
+    :members:

 Rag specific outputs
...
@@ -104,7 +104,7 @@ T5Tokenizer

 .. autoclass:: transformers.T5Tokenizer
     :members: build_inputs_with_special_tokens, get_special_tokens_mask,
-        create_token_type_ids_from_sequences, prepare_seq2seq_batch, save_vocabulary
+        create_token_type_ids_from_sequences, save_vocabulary

 T5TokenizerFast
...
@@ -71,7 +71,7 @@ tiny_model = FSMTForConditionalGeneration(config)
 print(f"num of params {tiny_model.num_parameters()}")

 # Test
-batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
+batch = tokenizer(["Making tiny model"], return_tensors="pt")
 outputs = tiny_model(**batch)
 print("test output:", len(outputs.logits[0]))
...
@@ -42,7 +42,7 @@ tiny_model = FSMTForConditionalGeneration(config)
 print(f"num of params {tiny_model.num_parameters()}")

 # Test
-batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
+batch = tokenizer(["Making tiny model"], return_tensors="pt")
 outputs = tiny_model(**batch)
 print("test output:", len(outputs.logits[0]))
...
@@ -522,13 +522,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
         >>> src = 'fr'  # source language
         >>> trg = 'en'  # target language
         >>> sample_text = "où est l'arrêt de bus ?"
-        >>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
+        >>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'

-        >>> model = MarianMTModel.from_pretrained(mname)
-        >>> tok = MarianTokenizer.from_pretrained(mname)
-        >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt")  # don't need tgt_text for inference
+        >>> model = MarianMTModel.from_pretrained(model_name)
+        >>> tokenizer = MarianTokenizer.from_pretrained(model_name)
+        >>> batch = tokenizer([sample_text], return_tensors="pt")
         >>> gen = model.generate(**batch)
-        >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True)  # returns "Where is the bus stop ?"
+        >>> tokenizer.batch_decode(gen, skip_special_tokens=True)
+        "Where is the bus stop ?"
 """

 MARIAN_INPUTS_DOCSTRING = r"""
...
@@ -557,13 +557,14 @@ MARIAN_GENERATION_EXAMPLE = r"""
         >>> src = 'fr'  # source language
         >>> trg = 'en'  # target language
         >>> sample_text = "où est l'arrêt de bus ?"
-        >>> mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
+        >>> model_name = f'Helsinki-NLP/opus-mt-{src}-{trg}'

-        >>> model = MarianMTModel.from_pretrained(mname)
-        >>> tok = MarianTokenizer.from_pretrained(mname)
-        >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="tf")  # don't need tgt_text for inference
+        >>> model = TFMarianMTModel.from_pretrained(model_name)
+        >>> tokenizer = MarianTokenizer.from_pretrained(model_name)
+        >>> batch = tokenizer([sample_text], return_tensors="tf")
         >>> gen = model.generate(**batch)
-        >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True)  # returns "Where is the bus stop ?"
+        >>> tokenizer.batch_decode(gen, skip_special_tokens=True)
+        "Where is the bus stop ?"
 """

 MARIAN_INPUTS_DOCSTRING = r"""
...
@@ -80,12 +80,15 @@ class MarianTokenizer(PreTrainedTokenizer):
     Examples::

         >>> from transformers import MarianTokenizer
-        >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
+        >>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
         >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
         >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
-        >>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
-        >>> # keys [input_ids, attention_mask, labels].
-        >>> # model(**batch) should work
+        >>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
+        >>> inputs["labels"] = labels["input_ids"]
+        >>> # keys [input_ids, attention_mask, labels]
+        >>> outputs = model(**inputs)  # should work
     """

     vocab_files_names = vocab_files_names
...
@@ -59,30 +59,23 @@ class MBartTokenizer(XLMRobertaTokenizer):
     """
     Construct an MBART tokenizer.

-    :class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer` and adds a new
-    :meth:`~transformers.MBartTokenizer.prepare_seq2seq_batch`
-
-    Refer to superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
+    :class:`~transformers.MBartTokenizer` is a subclass of :class:`~transformers.XLMRobertaTokenizer`. Refer to
+    superclass :class:`~transformers.XLMRobertaTokenizer` for usage examples and documentation concerning the
     initialization parameters and other methods.

-    .. warning::
-
-        ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
-        properly.
-
     The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
     <tokens> <eos>`` for target language documents.

     Examples::

         >>> from transformers import MBartTokenizer
-        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
+        >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
         >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
         >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-        >>> batch: dict = tokenizer.prepare_seq2seq_batch(
-        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
-        ... )
+        >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")
+        >>> inputs["labels"] = labels["input_ids"]
     """

     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
@@ -92,26 +85,38 @@ class MBartTokenizer(XLMRobertaTokenizer):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []

-    def __init__(self, *args, tokenizer_file=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
+    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
+        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)

         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {
             code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
         }
         self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
-        self.cur_lang_code = self.lang_code_to_id["en_XX"]
         self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset

         self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
-        self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
+
+        self._src_lang = src_lang if src_lang is not None else "en_XX"
+        self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)

     @property
     def vocab_size(self):
         return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token

+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
     ) -> List[int]:
@@ -181,7 +186,6 @@ class MBartTokenizer(XLMRobertaTokenizer):
     ) -> BatchEncoding:
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
-        self.set_src_lang_special_tokens(self.src_lang)
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

     @contextmanager
...
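
The new ``src_lang`` property shown above recomputes the special tokens whenever it is assigned. A minimal sketch of how it is meant to be used, assuming the ``facebook/mbart-large-en-ro`` checkpoint (the printed codes are illustrative, not verified output):

.. code-block:: python

    from transformers import MBartTokenizer

    # Assumed checkpoint; any MBart checkpoint with the 25 fairseq language codes behaves the same way.
    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")

    # Source text is encoded as "<tokens> <eos> <src_lang_code>", so the last id is the language code.
    en_ids = tokenizer("UN Chief Says There Is No Military Solution in Syria")["input_ids"]
    print(tokenizer.convert_ids_to_tokens(en_ids)[-1])  # expected: 'en_XX'

    # Assigning src_lang goes through the new setter, which calls set_src_lang_special_tokens,
    # so subsequent calls use the new language code without rebuilding the tokenizer.
    tokenizer.src_lang = "ro_RO"
    ro_ids = tokenizer("Şeful ONU declară că nu există o soluţie militară în Siria")["input_ids"]
    print(tokenizer.convert_ids_to_tokens(ro_ids)[-1])  # expected: 'ro_RO'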
@@ -70,15 +70,9 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE
     <https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.

-    :class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
-    a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
-
-    Refer to superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning
-    the initialization parameters and other methods.
-
-    .. warning::
-
-        ``prepare_seq2seq_batch`` should be used to encode inputs. Other tokenizer methods like ``encode`` do not work
-        properly.
+    :class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast`. Refer to
+    superclass :class:`~transformers.XLMRobertaTokenizerFast` for usage examples and documentation concerning the
+    initialization parameters and other methods.

     The tokenization method is ``<tokens> <eos> <language code>`` for source language documents, and ``<language code>
     <tokens> <eos>`` for target language documents.
@@ -86,12 +80,13 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     Examples::

         >>> from transformers import MBartTokenizerFast
-        >>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro')
+        >>> tokenizer = MBartTokenizerFast.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
         >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
         >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-        >>> batch: dict = tokenizer.prepare_seq2seq_batch(
-        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
-        ... )
+        >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(expected_translation_romanian, return_tensors="pt")
+        >>> inputs["labels"] = labels["input_ids"]
     """

     vocab_files_names = {"vocab_file": "sentencepiece.bpe.model"}
@@ -102,14 +97,25 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []

-    def __init__(self, *args, tokenizer_file=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, **kwargs)
-        self.cur_lang_code = self.convert_tokens_to_ids("en_XX")
-        self.set_src_lang_special_tokens(kwargs.get("src_lang", "en_XX"))
+    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
+        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)

         self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
+
+        self._src_lang = src_lang if src_lang is not None else "en_XX"
+        self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
+        self.tgt_lang = tgt_lang
+        self.set_src_lang_special_tokens(self._src_lang)

+    @property
+    def src_lang(self) -> str:
+        return self._src_lang
+
+    @src_lang.setter
+    def src_lang(self, new_src_lang: str) -> None:
+        self._src_lang = new_src_lang
+        self.set_src_lang_special_tokens(self._src_lang)
+
     def get_special_tokens_mask(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
     ) -> List[int]:
@@ -181,7 +187,6 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     ) -> BatchEncoding:
         self.src_lang = src_lang
         self.tgt_lang = tgt_lang
-        self.set_src_lang_special_tokens(self.src_lang)
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

     @contextmanager
...
@@ -31,13 +31,17 @@ class MT5Model(T5Model):
     alongside usage examples.

     Examples::

         >>> from transformers import MT5Model, T5Tokenizer
         >>> model = MT5Model.from_pretrained("google/mt5-small")
         >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
         >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
         >>> summary = "Weiter Verhandlung in Syrien."
-        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
-        >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
+        >>> inputs = tokenizer(article, return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(summary, return_tensors="pt")
+        >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
         >>> hidden_states = outputs.last_hidden_state
     """

     model_type = "mt5"
@@ -59,13 +63,17 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
     appropriate documentation alongside usage examples.

     Examples::

         >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
         >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
         >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
         >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
         >>> summary = "Weiter Verhandlung in Syrien."
-        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
-        >>> outputs = model(**batch)
+        >>> inputs = tokenizer(article, return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(summary, return_tensors="pt")
+        >>> outputs = model(**inputs, labels=labels["input_ids"])
         >>> loss = outputs.loss
     """
...
@@ -31,15 +31,17 @@ class TFMT5Model(TFT5Model):
     documentation alongside usage examples.

     Examples::

         >>> from transformers import TFMT5Model, T5Tokenizer
         >>> model = TFMT5Model.from_pretrained("google/mt5-small")
         >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
         >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
         >>> summary = "Weiter Verhandlung in Syrien."
-        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
-        >>> batch["decoder_input_ids"] = batch["labels"]
-        >>> del batch["labels"]
-        >>> outputs = model(batch)
+        >>> inputs = tokenizer(article, return_tensors="tf")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(summary, return_tensors="tf")
+        >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
         >>> hidden_states = outputs.last_hidden_state
     """

     model_type = "mt5"
@@ -52,13 +54,17 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
     appropriate documentation alongside usage examples.

     Examples::

         >>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
         >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
         >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
         >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
         >>> summary = "Weiter Verhandlung in Syrien."
-        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
-        >>> outputs = model(batch)
+        >>> inputs = tokenizer(article, return_tensors="tf")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     labels = tokenizer(summary, return_tensors="tf")
+        >>> outputs = model(**inputs, labels=labels["input_ids"])
         >>> loss = outputs.loss
     """
...
@@ -550,10 +550,8 @@ class RagModel(RagPreTrainedModel):
         >>> # initialize with RagRetriever to do everything in one forward call
         >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

-        >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
-        >>> input_ids = input_dict["input_ids"]
-        >>> outputs = model(input_ids=input_ids)
+        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+        >>> outputs = model(input_ids=inputs["input_ids"])
         """
         n_docs = n_docs if n_docs is not None else self.config.n_docs
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -752,9 +750,12 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         >>> # initialize with RagRetriever to do everything in one forward call
         >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

-        >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
-        >>> input_ids = input_dict["input_ids"]
-        >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
+        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+        >>> input_ids = inputs["input_ids"]
+        >>> labels = targets["input_ids"]
+        >>> outputs = model(input_ids=input_ids, labels=labels)

         >>> # or use retriever separately
         >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
@@ -764,7 +765,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
         >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
         >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
         >>> # 3. Forward to generator
-        >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
+        >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)
         """
         n_docs = n_docs if n_docs is not None else self.config.n_docs
         exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
@@ -1203,9 +1204,12 @@ class RagTokenForGeneration(RagPreTrainedModel):
         >>> # initialize with RagRetriever to do everything in one forward call
         >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

-        >>> input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
-        >>> input_ids = input_dict["input_ids"]
-        >>> outputs = model(input_ids=input_ids, labels=input_dict["labels"])
+        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
+        >>> with tokenizer.as_target_tokenizer():
+        ...     targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
+        >>> input_ids = inputs["input_ids"]
+        >>> labels = targets["input_ids"]
+        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
@@ -1215,7 +1219,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
         >>> doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
         >>> # 3. Forward to generator
-        >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
+        >>> outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=labels)

         >>> # or directly generate
         >>> generated = model.generate(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
...
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Tokenization classes for RAG."""
 import os
+import warnings
 from contextlib import contextmanager
 from typing import List, Optional
@@ -88,6 +89,13 @@ class RagTokenizer:
         truncation: bool = True,
         **kwargs,
     ) -> BatchEncoding:
+        warnings.warn(
+            "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
+            "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` "
+            "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
+            "details",
+            FutureWarning,
+        )
         if max_length is None:
             max_length = self.current_tokenizer.model_max_length
         model_inputs = self(
...
@@ -3303,6 +3303,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is passed.
             Otherwise, input_ids, attention_mask will be the only keys.
         """
+        warnings.warn(
+            "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
+            "regular `__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` "
+            "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
+            "details",
+            FutureWarning,
+        )
         # mBART-specific kwargs that should be ignored by other models.
         kwargs.pop("src_lang", None)
         kwargs.pop("tgt_lang", None)
...
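
Because the method now only warns rather than fails, callers can confirm the deprecation (or keep an unmigrated pipeline quiet) with the standard ``warnings`` machinery. A minimal sketch, assuming the Marian en-de checkpoint used elsewhere in this diff:

.. code-block:: python

    import warnings

    from transformers import MarianTokenizer

    tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        tokenizer.prepare_seq2seq_batch(["I am a small frog."], return_tensors="pt")

    # The deprecated path still works but emits a FutureWarning.
    assert any(issubclass(w.category, FutureWarning) for w in caught)

    # To silence it until the code is migrated to __call__ / as_target_tokenizer:
    warnings.filterwarnings("ignore", message=".*prepare_seq2seq_batch.*", category=FutureWarning)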
@@ -354,9 +354,7 @@ class MarianIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.expected_text, generated_words)

     def translate_src_text(self, **tokenizer_kwargs):
-        model_inputs = self.tokenizer.prepare_seq2seq_batch(
-            src_texts=self.src_text, return_tensors="pt", **tokenizer_kwargs
-        ).to(torch_device)
+        model_inputs = self.tokenizer(self.src_text, return_tensors="pt", **tokenizer_kwargs).to(torch_device)
         self.assertEqual(self.model.device, model_inputs.input_ids.device)
         generated_ids = self.model.generate(
             model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
@@ -373,9 +371,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
         src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
         expected_ids = [38, 121, 14, 697, 38848, 0]

-        model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt, return_tensors="pt").to(
-            torch_device
-        )
+        model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
+        with self.tokenizer.as_target_tokenizer():
+            targets = self.tokenizer(tgt, return_tensors="pt")
+        model_inputs["labels"] = targets["input_ids"].to(torch_device)
         self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
@@ -397,16 +396,12 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
     def test_unk_support(self):
         t = self.tokenizer
-        ids = t.prepare_seq2seq_batch(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
+        ids = t(["||"], return_tensors="pt").to(torch_device).input_ids[0].tolist()
         expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
         self.assertEqual(expected, ids)

     def test_pad_not_split(self):
-        input_ids_w_pad = (
-            self.tokenizer.prepare_seq2seq_batch(["I am a small frog <pad>"], return_tensors="pt")
-            .input_ids[0]
-            .tolist()
-        )
+        input_ids_w_pad = self.tokenizer(["I am a small frog <pad>"], return_tensors="pt").input_ids[0].tolist()
         expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0]  # pad
         self.assertListEqual(expected_w_pad, input_ids_w_pad)
...
@@ -349,7 +349,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
     @slow
     def test_enro_generate_one(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(
+        batch: BatchEncoding = self.tokenizer(
             ["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt"
         ).to(torch_device)
         translated_tokens = self.model.generate(**batch)
@@ -359,9 +359,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
     @slow
     def test_enro_generate_batch(self):
-        batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text, return_tensors="pt").to(
-            torch_device
-        )
+        batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device)
         translated_tokens = self.model.generate(**batch)
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
         assert self.tgt_text == decoded
@@ -412,7 +410,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
     @unittest.skip("This test is broken, still generates english")
     def test_cc25_generate(self):
-        inputs = self.tokenizer.prepare_seq2seq_batch([self.src_text[0]], return_tensors="pt").to(torch_device)
+        inputs = self.tokenizer([self.src_text[0]], return_tensors="pt").to(torch_device)
         translated_tokens = self.model.generate(
             input_ids=inputs["input_ids"].to(torch_device),
             decoder_start_token_id=self.tokenizer.lang_code_to_id["ro_RO"],
@@ -422,9 +420,7 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
     @slow
     def test_fill_mask(self):
-        inputs = self.tokenizer.prepare_seq2seq_batch(["One of the best <mask> I ever read!"], return_tensors="pt").to(
-            torch_device
-        )
+        inputs = self.tokenizer(["One of the best <mask> I ever read!"], return_tensors="pt").to(torch_device)
         outputs = self.model.generate(
             inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
         )
...