"PyTorch/NAS/git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "0fc002dfc863089e33ea2dee33b0827046e4d174"
Unverified commit 3095ee9d, authored by Lysandre Debut and committed by GitHub.

Tokenizers should be framework agnostic (#8599)



* Tokenizers should be framework agnostic

* Run the slow tests

* Not testing

* Fix documentation

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7f3b41a3
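
In short: `prepare_seq2seq_batch` used to default to `return_tensors="pt"`, which tied every tokenizer to PyTorch. After this change the default is `None`, so the method returns plain Python lists and each framework opts in explicitly. A minimal sketch of the new contract (the `Helsinki-NLP/opus-mt-en-de` checkpoint is borrowed from the MarianTokenizer docstring touched below):

```python
# Minimal sketch of the framework-agnostic behavior introduced here.
# The checkpoint name comes from the MarianTokenizer docstring in this diff.
from transformers import MarianMTModel, MarianTokenizer

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
src_text = ["I am a small frog."]

# New default (return_tensors=None): token ids come back as Python lists,
# with no framework imported or assumed.
batch_lists = tokenizer.prepare_seq2seq_batch(src_text)

# PyTorch callers now ask for tensors explicitly, which is exactly the
# edit applied to every doc snippet and model card in this diff.
batch_pt = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
print(tokenizer.batch_decode(model.generate(**batch_pt), skip_special_tokens=True))
```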
@@ -16,7 +16,7 @@ on:
 jobs:
   run_tests_torch_gpu:
-    runs-on: [self-hosted, single-gpu]
+    runs-on: [self-hosted, gpu, single-gpu]
     steps:
     - uses: actions/checkout@v2
     - name: Python version
@@ -86,7 +86,7 @@ jobs:
   run_tests_tf_gpu:
-    runs-on: [self-hosted, single-gpu]
+    runs-on: [self-hosted, gpu, single-gpu]
     steps:
     - uses: actions/checkout@v2
     - name: Python version
@@ -154,7 +154,7 @@ jobs:
         path: reports
   run_tests_torch_multi_gpu:
-    runs-on: [self-hosted, multi-gpu]
+    runs-on: [self-hosted, gpu, multi-gpu]
     steps:
     - uses: actions/checkout@v2
     - name: Python version
@@ -213,7 +213,7 @@ jobs:
         path: reports
   run_tests_tf_multi_gpu:
-    runs-on: [self-hosted, multi-gpu]
+    runs-on: [self-hosted, gpu, multi-gpu]
     steps:
     - uses: actions/checkout@v2
     - name: Python version
...
@@ -9,13 +9,14 @@ on:
   push:
     branches:
       - ci_*
+      - framework-agnostic-tokenizers
   repository_dispatch:
   schedule:
     - cron: "0 0 * * *"
 jobs:
   run_all_tests_torch_gpu:
-    runs-on: [self-hosted, single-gpu]
+    runs-on: [self-hosted, gpu, single-gpu]
     steps:
     - uses: actions/checkout@v2
@@ -109,7 +110,7 @@ jobs:
   run_all_tests_tf_gpu:
-    runs-on: [self-hosted, single-gpu]
+    runs-on: [self-hosted, gpu, single-gpu]
     steps:
     - uses: actions/checkout@v2
@@ -188,7 +189,7 @@ jobs:
         path: reports
   run_all_tests_torch_multi_gpu:
-    runs-on: [self-hosted, multi-gpu]
+    runs-on: [self-hosted, gpu, multi-gpu]
     steps:
     - uses: actions/checkout@v2
@@ -279,7 +280,7 @@ jobs:
         path: reports
   run_all_tests_tf_multi_gpu:
-    runs-on: [self-hosted, multi-gpu]
+    runs-on: [self-hosted, gpu, multi-gpu]
     steps:
     - uses: actions/checkout@v2
...
@@ -78,7 +78,7 @@ require 3 character language codes:
     tokenizer = MarianTokenizer.from_pretrained(model_name)
     print(tokenizer.supported_language_codes)
     model = MarianMTModel.from_pretrained(model_name)
-    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
+    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
     tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
     # ["c'est une phrase en anglais que nous voulons traduire en français",
     # 'Isto deve ir para o português.',
@@ -150,7 +150,7 @@ Example of translating english to many romance languages, using old-style 2 character language codes
     print(tokenizer.supported_language_codes)
     model = MarianMTModel.from_pretrained(model_name)
-    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text))
+    translated = model.generate(**tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt"))
     tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
     # ["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español']
...
@@ -44,7 +44,7 @@ the sequences for sequence-to-sequence fine-tuning.
     example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
     expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
-    batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian)
+    batch = tokenizer.prepare_seq2seq_batch(example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt")
     model(input_ids=batch['input_ids'], labels=batch['labels']) # forward pass
 - Generation
@@ -58,7 +58,7 @@ the sequences for sequence-to-sequence fine-tuning.
     model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
     tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
     article = "UN Chief Says There Is No Military Solution in Syria"
-    batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX")
+    batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], src_lang="en_XX", return_tensors="pt")
     translated_tokens = model.generate(**batch, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
     translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
     assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
...
@@ -78,7 +78,7 @@ Usage Example
     torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
     tokenizer = PegasusTokenizer.from_pretrained(model_name)
     model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
-    batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest').to(torch_device)
+    batch = tokenizer.prepare_seq2seq_batch(src_text, truncation=True, padding='longest', return_tensors="pt").to(torch_device)
     translated = model.generate(**batch)
     tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
     assert tgt_text[0] == "California's largest electricity provider has turned off power to hundreds of thousands of customers."
...
@@ -11,7 +11,7 @@ tokenizer = PegasusTokenizer.from_pretrained(model_name)
 model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
 def get_response(input_text, num_return_sequences):
-  batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest', max_length=60).to(torch_device)
+  batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest', max_length=60, return_tensors="pt").to(torch_device)
   translated = model.generate(**batch, max_length=60, num_beams=10, num_return_sequences=num_return_sequences, temperature=1.5)
   tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
   return tgt_text
...
@@ -12,7 +12,7 @@ model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
 def get_answer(question, context):
   input_text = "question: %s text: %s" % (question, context)
-  batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest').to(torch_device)
+  batch = tokenizer.prepare_seq2seq_batch([input_text], truncation=True, padding='longest', return_tensors="pt").to(torch_device)
   translated = model.generate(**batch)
   tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
   return tgt_text[0]
...
@@ -58,7 +58,7 @@ tiny_model = FSMTForConditionalGeneration(config)
 print(f"num of params {tiny_model.num_parameters()}")
 # Test
-batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"])
+batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
 outputs = tiny_model(**batch)
 print("test output:", len(outputs.logits[0]))
...
@@ -29,7 +29,7 @@ tiny_model = FSMTForConditionalGeneration(config)
 print(f"num of params {tiny_model.num_parameters()}")
 # Test
-batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"])
+batch = tokenizer.prepare_seq2seq_batch(["Making tiny model"], return_tensors="pt")
 outputs = tiny_model(**batch)
 print("test output:", len(outputs.logits[0]))
...
@@ -15,7 +15,9 @@
 from typing import List, Optional
 
-from ...tokenization_utils_base import BatchEncoding
+from transformers import add_start_docstrings
+
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta import RobertaTokenizer
@@ -54,6 +56,7 @@ class BartTokenizer(RobertaTokenizer):
         "merges_file": {m: merges_url for m in _all_bart_models},
     }
 
+    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
@@ -61,70 +64,10 @@ class BartTokenizer(RobertaTokenizer):
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
         padding: str = "longest",
-        return_tensors: str = "None",
+        return_tensors: str = None,
         truncation=True,
         **kwargs,
     ) -> BatchEncoding:
-        r"""
-        Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
-
-        Args:
-            src_texts: (:obj:`List[str]`):
-                List of documents to summarize or source language texts.
-            tgt_texts: (:obj:`List[str]`, `optional`):
-                List of summaries or target language texts.
-            max_length (:obj:`int`, `optional`):
-                Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
-                left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
-                length is required by one of the truncation/padding parameters. If the model has no specific maximum
-                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
-            max_target_length (:obj:`int`, `optional`):
-                Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
-                set to :obj:`None`, this will use the max_length value.
-            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
-                Activates and controls padding. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-                  single sequence if provided).
-                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to
-                  the maximum acceptable input length for the model if that argument is not provided.
-                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences
-                  of different lengths).
-            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
-                If set, will return tensors instead of list of python integers. Acceptable values are:
-
-                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
-                Activates and controls truncation. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
-                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
-                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
-                  if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will
-                  only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will
-                  only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is
-                  provided.
-                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
-                  sequence lengths greater than the model maximum admissible input size).
-            **kwargs:
-                Additional keyword arguments passed along to :obj:`self.__call__`.
-
-        Returns:
-            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
-
-            - **input_ids** -- List of token ids to be fed to the encoder.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
-            - **labels** -- List of token ids for tgt_texts
-
-            The full set of keys ``[input_ids, attention_mask, labels]``, will only be returned if tgt_texts is
-            passed. Otherwise, input_ids, attention_mask will be the only keys.
-        """
         kwargs.pop("src_lang", None)
         kwargs.pop("tgt_lang", None)
         if max_length is None:
...
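
Both BART tokenizers now pull their `prepare_seq2seq_batch` documentation from the shared `PREPARE_SEQ2SEQ_BATCH_DOCSTRING` constant via `add_start_docstrings`, instead of each carrying its own ~60-line copy. A sketch in the spirit of that decorator (illustrative only; the real transformers helper may differ in detail):

```python
# Illustrative sketch of a docstring-prepending decorator like
# transformers.add_start_docstrings (not the exact upstream code).
def add_start_docstrings(*docstr):
    def decorator(fn):
        # Prepend the shared text to whatever docstring the function defines.
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator

# Hypothetical shared docstring standing in for PREPARE_SEQ2SEQ_BATCH_DOCSTRING.
SHARED_DOC = "Prepare a batch of encoder/decoder inputs for a seq2seq model.\n\n"

@add_start_docstrings(SHARED_DOC)
def prepare_seq2seq_batch(src_texts, tgt_texts=None):
    """Model-specific notes stay local to each tokenizer."""

print(prepare_seq2seq_batch.__doc__)  # shared text followed by the local notes
```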
@@ -15,7 +15,9 @@
 from typing import List, Optional
 
-from ...tokenization_utils_base import BatchEncoding
+from transformers import add_start_docstrings
+
+from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
 from ...utils import logging
 from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
 from .tokenization_bart import BartTokenizer
@@ -49,6 +51,7 @@ class BartTokenizerFast(RobertaTokenizerFast):
     }
     slow_tokenizer_class = BartTokenizer
 
+    @add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
@@ -56,72 +59,10 @@ class BartTokenizerFast(RobertaTokenizerFast):
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
         padding: str = "longest",
-        return_tensors: str = "None",
+        return_tensors: Optional[str] = None,
         truncation=True,
         **kwargs,
     ) -> BatchEncoding:
-        r"""
-        Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
-
-        Args:
-            src_texts: (:obj:`List[str]`):
-                List of documents to summarize or source language texts.
-            tgt_texts: (:obj:`List[str]`, `optional`):
-                List of summaries or target language texts.
-            max_length (:obj:`int`, `optional`):
-                Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
-                left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
-                length is required by one of the truncation/padding parameters. If the model has no specific maximum
-                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
-            max_target_length (:obj:`int`, `optional`):
-                Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or
-                set to :obj:`None`, this will use the max_length value.
-            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
-                Activates and controls padding. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
-                  single sequence if provided).
-                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to
-                  the maximum acceptable input length for the model if that argument is not provided.
-                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences
-                  of different lengths).
-            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
-                If set, will return tensors instead of list of python integers. Acceptable values are:
-
-                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
-                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
-                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
-            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
-                Activates and controls truncation. Accepts the following values:
-
-                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
-                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
-                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
-                  if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will
-                  only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
-                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
-                  to the maximum acceptable input length for the model if that argument is not provided. This will
-                  only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is
-                  provided.
-                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
-                  sequence lengths greater than the model maximum admissible input size).
-            **kwargs:
-                Additional keyword arguments passed along to :obj:`self.__call__`.
-
-        Returns:
-            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
-
-            - **input_ids** -- List of token ids to be fed to the encoder.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
-            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
-            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the
-              decoder. This does not include causal mask, which is built by the model.
-
-            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, will
-            only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
-        """
         if max_length is None:
             max_length = self.model_max_length
         model_inputs: BatchEncoding = self(
...
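
One detail worth noting in the two deleted docstrings: the slow tokenizer documented a `labels` key for `tgt_texts`, while the fast one documented `decoder_input_ids`/`decoder_attention_mask`. A quick way to check what either class actually returns (the checkpoint name is an assumption; any public BART checkpoint works):

```python
# Inspect the keys each BART tokenizer returns when target texts are passed.
# "facebook/bart-large" is just a convenient public checkpoint.
from transformers import BartTokenizer, BartTokenizerFast

slow = BartTokenizer.from_pretrained("facebook/bart-large")
fast = BartTokenizerFast.from_pretrained("facebook/bart-large")

src, tgt = ["A source document."], ["A target summary."]
print(sorted(slow.prepare_seq2seq_batch(src, tgt_texts=tgt).keys()))
print(sorted(fast.prepare_seq2seq_batch(src, tgt_texts=tgt).keys()))
```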
@@ -491,7 +491,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        return_tensors: str = "pt",
+        return_tensors: Optional[str] = None,
         truncation=True,
         padding="longest",
         **unused,
...
@@ -41,7 +41,7 @@ class MarianMTModel(BartForConditionalGeneration):
         >>> model = MarianMTModel.from_pretrained(mname)
         >>> tok = MarianTokenizer.from_pretrained(mname)
-        >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text])  # don't need tgt_text for inference
+        >>> batch = tok.prepare_seq2seq_batch(src_texts=[sample_text], return_tensors="pt")  # don't need tgt_text for inference
         >>> gen = model.generate(**batch)  # for forward pass: model(**batch)
         >>> words: List[str] = tok.batch_decode(gen, skip_special_tokens=True)  # returns "Where is the bus stop ?"
...
@@ -70,7 +70,7 @@ class MarianTokenizer(PreTrainedTokenizer):
         >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
         >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
         >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
-        >>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts)
+        >>> batch_enc: BatchEncoding = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
         >>> # keys [input_ids, attention_mask, labels].
         >>> # model(**batch) should work
     """
@@ -175,7 +175,7 @@ class MarianTokenizer(PreTrainedTokenizer):
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        return_tensors: str = "pt",
+        return_tensors: Optional[str] = None,
         truncation=True,
         padding="longest",
         **unused,
...
@@ -22,7 +22,7 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
         >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
         >>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")
         >>> article = "UN Chief Says There Is No Military Solution in Syria"
-        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article])
+        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], return_tensors="pt")
         >>> translated_tokens = model.generate(**batch)
         >>> translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
         >>> assert translation == "Şeful ONU declară că nu există o soluţie militară în Siria"
...
@@ -81,7 +81,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
         >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
         >>> batch: dict = tokenizer.prepare_seq2seq_batch(
-        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
+        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
         ... )
     """
@@ -183,7 +183,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
         max_target_length: Optional[int] = None,
         truncation: bool = True,
         padding: str = "longest",
-        return_tensors: str = "pt",
+        return_tensors: Optional[str] = None,
         add_prefix_space: bool = False,  # ignored
         **kwargs,
     ) -> BatchEncoding:
...
@@ -89,7 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         >>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
         >>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
         >>> batch: dict = tokenizer.prepare_seq2seq_batch(
-        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian
+        ...     example_english_phrase, src_lang="en_XX", tgt_lang="ro_RO", tgt_texts=expected_translation_romanian, return_tensors="pt"
         ... )
     """
@@ -181,7 +181,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         max_target_length: Optional[int] = None,
         truncation: bool = True,
         padding: str = "longest",
-        return_tensors: str = "pt",
+        return_tensors: str = None,
         **kwargs,
     ) -> BatchEncoding:
         if max_length is None:
...
@@ -38,7 +38,7 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
         >>> model = PegasusForConditionalGeneration.from_pretrained(mname)
         >>> tok = PegasusTokenizer.from_pretrained(mname)
-        >>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE])  # don't need tgt_text for inference
+        >>> batch = tok.prepare_seq2seq_batch(src_texts=[PGE_ARTICLE], return_tensors="pt")  # don't need tgt_text for inference
         >>> gen = model.generate(**batch)  # for forward pass: model(**batch)
         >>> summary: List[str] = tok.batch_decode(gen, skip_special_tokens=True)
         >>> assert summary == "California's largest electricity provider has turned off power to tens of thousands of customers."
...
@@ -134,7 +134,7 @@ class PegasusTokenizer(ReformerTokenizer):
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        return_tensors: str = "pt",
+        return_tensors: str = None,
         truncation=True,
         padding="longest",
         **unused,
...
@@ -95,7 +95,7 @@ class PegasusTokenizerFast(ReformerTokenizerFast):
         tgt_texts: Optional[List[str]] = None,
         max_length: Optional[int] = None,
         max_target_length: Optional[int] = None,
-        return_tensors: str = "pt",
+        return_tensors: str = None,
         truncation=True,
         padding="longest",
         **unused,
...
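
Taken together, every `prepare_seq2seq_batch` override above drops its `return_tensors="pt"` default in favor of `None`, so the choice of framework is always the caller's. A hedged sketch of the resulting options, using the FSMT tokenizer changed above (the `facebook/wmt19-en-de` checkpoint is assumed here for illustration):

```python
# The same batch in each representation, selected explicitly by the caller.
from transformers import FSMTTokenizer

tok = FSMTTokenizer.from_pretrained("facebook/wmt19-en-de")
texts = ["Machine learning is great, isn't it?"]

as_lists = tok.prepare_seq2seq_batch(texts)                       # plain Python lists (new default)
as_numpy = tok.prepare_seq2seq_batch(texts, return_tensors="np")  # numpy arrays
as_torch = tok.prepare_seq2seq_batch(texts, return_tensors="pt")  # torch tensors (requires PyTorch)
as_tf    = tok.prepare_seq2seq_batch(texts, return_tensors="tf")  # tf tensors (requires TensorFlow)
```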