Unverified Commit 8aa67fc1 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing mbart50 with `return_tensors` argument too. (#13301)

* Fixing mbart50 with `return_tensors` argument too.

* Adding mbart50 tokenization tests.
parent b89a964d
......@@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None:
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
......
......@@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
)
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None:
raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
......
......@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
inputs = self.tokenizer._build_translation_inputs(
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
)
self.assertEqual(
nested_simplify(inputs),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment