"docs/source/vscode:/vscode.git/clone" did not exist on "df2af6d8b8765b1ac2cda12d2ece09bf7240fba8"
Unverified Commit 8aa67fc1 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing mbart50 with `return_tensors` argument too. (#13301)

* Fixing mbart50 with `return_tensors` argument too.

* Adding mbart50 tokenization tests.
parent b89a964d
...@@ -304,12 +304,14 @@ class MBart50Tokenizer(PreTrainedTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None:
    raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
......
...@@ -245,12 +245,14 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
)
def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs): def _build_translation_inputs(
self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
):
"""Used by translation pipeline, to prepare inputs for the generate function"""
if src_lang is None or tgt_lang is None:
    raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs) inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs
......
...@@ -229,7 +229,9 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
@require_torch
def test_tokenizer_translation(self):
inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR") inputs = self.tokenizer._build_translation_inputs(
"A test", return_tensors="pt", src_lang="en_XX", tgt_lang="ar_AR"
)
self.assertEqual(
    nested_simplify(inputs),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment