Unverified Commit 92970c0c authored by Nicolas Patry, committed by GitHub

Enabling multilingual models for translation pipelines. (#10536)



* [WIP] Enabling multilingual models for translation pipelines.

* decoder_input_ids -> forced_bos_token_id

* Improve docstring.

* Rebase

* Fixing 2 bugs

- Fix the type of `token_ids` coming from `_parse_and_tokenize`
- Fix a wrong index derived from `tgt_lang`.

* Fixing black version.

* Adding tests for `_build_translation_inputs` and adding them for all tokenizers.

* Mbart actually puts the lang code at the end.

* Fixing m2m100.

* Adding TF support to `deep_round`.

* Update src/transformers/pipelines/text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding one line comment.

* Fixing M2M100 `_build_translation_input_ids`, and fix the call site.

* Fixing tests + deep_round -> nested_simplify
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5254220e
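For context, a minimal usage sketch of the feature this commit enables, assembled from the tests added below (the checkpoint name is the one the slow test uses; this is an illustrative sketch, not part of the diff):

    from transformers import pipeline

    # src_lang/tgt_lang can be set when building the pipeline...
    translator = pipeline(
        task="translation",
        model="facebook/mbart-large-50-many-to-many-mmt",
        src_lang="en_XX",
        tgt_lang="ar_AR",
    )
    print(translator("This is a test"))  # [{'translation_text': 'هذا إختبار'}]

    # ...or overridden per call:
    print(translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN"))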
@@ -288,6 +288,16 @@ class M2M100Tokenizer(PreTrainedTokenizer):
         self.set_src_lang_special_tokens(self.src_lang)
         return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

+    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the `generate` function."""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = src_lang
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        tgt_lang_id = self.get_lang_id(tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
     @contextmanager
     def as_target_tokenizer(self):
         """
...
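A quick sketch of what the new helper returns, using the token ids from the M2M100 test added at the bottom of this diff (the checkpoint name is an assumption; note the call goes through a private helper meant for the pipeline):

    from transformers import M2M100Tokenizer

    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")  # assumed checkpoint
    inputs = tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
    print(inputs["input_ids"])            # tensor([[128022, 58, 4183, 2]]) -> __en__, A, test, EOS
    print(inputs["forced_bos_token_id"])  # 128006 -> __ar__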
@@ -186,6 +186,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

+    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the `generate` function."""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = src_lang
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
...
@@ -278,6 +278,16 @@ class MBart50Tokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

+    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the `generate` function."""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = src_lang
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
...
@@ -241,6 +241,16 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
             special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
         )

+    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the `generate` function."""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = src_lang
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
...
@@ -160,6 +160,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

+    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the `generate` function."""
+        if src_lang is None or tgt_lang is None:
+            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
+        self.src_lang = src_lang
+        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
+        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
+        inputs["forced_bos_token_id"] = tgt_lang_id
+        return inputs
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
...
@@ -616,7 +616,10 @@ class Pipeline(_ScikitCompat):
         Return:
             :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
         """
-        return {name: tensor.to(self.device) for name, tensor in inputs.items()}
+        return {
+            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
+            for name, tensor in inputs.items()
+        }

     def check_model_type(self, supported_models: Union[List[str], dict]):
         """
...
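The `isinstance` guard above matters because `_build_translation_inputs` now mixes a plain int into the encoding, and calling `.to(device)` on an int would raise. A minimal standalone sketch of the new behavior (not the library code itself):

    import torch

    def ensure_tensor_on_device(device, **inputs):
        # Non-tensor entries (e.g. the plain-int forced_bos_token_id) pass through untouched.
        return {
            name: tensor.to(device) if isinstance(tensor, torch.Tensor) else tensor
            for name, tensor in inputs.items()
        }

    batch = {"input_ids": torch.tensor([[62, 3034, 2]]), "forced_bos_token_id": 250001}
    moved = ensure_tensor_on_device(torch.device("cpu"), **batch)
    assert moved["forced_bos_token_id"] == 250001  # still a plain int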
+from typing import Optional
+
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
@@ -50,6 +52,28 @@ class Text2TextGenerationPipeline(Pipeline):
         """
         return True

+    def _parse_and_tokenize(self, *args, truncation):
+        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
+        if isinstance(args[0], list):
+            assert (
+                self.tokenizer.pad_token_id is not None
+            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
+            args = ([prefix + arg for arg in args[0]],)
+            padding = True
+        elif isinstance(args[0], str):
+            args = (prefix + args[0],)
+            padding = False
+        else:
+            raise ValueError(
+                f"`args[0]`: {args[0]} has the wrong format. It should be of type `str` or of type `list`"
+            )
+        inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation)
+        # `token_type_ids` is produced by some tokenizers but is not a valid generate kwarg
+        if "token_type_ids" in inputs:
+            del inputs["token_type_ids"]
+        return inputs
+
     def __call__(
         self,
         *args,
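The `token_type_ids` deletion above exists because the pipeline later splats the whole encoding into `model.generate()` (see the `generate_kwargs.update(inputs)` hunk below), and `generate()` does not accept that key. A plain-dict sketch of the same cleanup:

    encoding = {"input_ids": [[0, 1]], "attention_mask": [[1, 1]], "token_type_ids": [[0, 0]]}
    encoding.pop("token_type_ids", None)  # same effect as the `del` above
    assert "token_type_ids" not in encoding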
@@ -88,25 +112,13 @@ class Text2TextGenerationPipeline(Pipeline):
         """
         assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"

-        prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
-        if isinstance(args[0], list):
-            assert (
-                self.tokenizer.pad_token_id is not None
-            ), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
-            args = ([prefix + arg for arg in args[0]],)
-            padding = True
-        elif isinstance(args[0], str):
-            args = (prefix + args[0],)
-            padding = False
-        else:
-            raise ValueError(
-                f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
-            )
         with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
+            inputs = self._parse_and_tokenize(*args, truncation=truncation)
+            return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
+
+    def _generate(
+        self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs
+    ):
         if self.framework == "pt":
             inputs = self.ensure_tensor_on_device(**inputs)
             input_length = inputs["input_ids"].shape[-1]
@@ -117,9 +129,9 @@ class Text2TextGenerationPipeline(Pipeline):
         max_length = generate_kwargs.get("max_length", self.model.config.max_length)
         self.check_inputs(input_length, min_length, max_length)

+        generate_kwargs.update(inputs)
         generations = self.model.generate(
-            inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
             **generate_kwargs,
         )
         results = []
@@ -226,6 +238,23 @@ class TranslationPipeline(Text2TextGenerationPipeline):
     # Used in the return key of the pipeline.
     return_name = "translation"
+    src_lang: Optional[str] = None
+    tgt_lang: Optional[str] = None
+
+    def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if src_lang is not None:
+            self.src_lang = src_lang
+        if tgt_lang is not None:
+            self.tgt_lang = tgt_lang
+        if src_lang is None and tgt_lang is None:
+            # Backward compatibility; passing src_lang/tgt_lang directly is preferred.
+            task = kwargs.get("task", "")
+            items = task.split("_")
+            if task and len(items) == 4:
+                # "translation_XX_to_YY" -> ["translation", "XX", "to", "YY"]
+                self.src_lang = items[1]
+                self.tgt_lang = items[3]
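The backward-compatibility branch above recovers the languages from legacy task names; a minimal sketch of the parsing:

    task = "translation_en_to_de"
    items = task.split("_")  # ["translation", "en", "to", "de"]
    if task and len(items) == 4:
        src_lang, tgt_lang = items[1], items[3]
    assert (src_lang, tgt_lang) == ("en", "de")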
@@ -233,8 +262,27 @@ class TranslationPipeline(Text2TextGenerationPipeline):
     def check_inputs(self, input_length: int, min_length: int, max_length: int):
         if input_length > 0.9 * max_length:
             logger.warning(
                 f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
                 "increasing your max_length manually, e.g. translator('...', max_length=400)"
             )
+        return True
+
+    def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
+        if getattr(self.tokenizer, "_build_translation_inputs", None):
+            return self.tokenizer._build_translation_inputs(
+                *args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
+            )
+        else:
+            return super()._parse_and_tokenize(*args, truncation=truncation)

-    def __call__(self, *args, **kwargs):
+    def __call__(
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        src_lang=None,
+        tgt_lang=None,
+        **generate_kwargs
+    ):
         r"""
         Translate the text(s) given as inputs.
@@ -247,6 +295,12 @@ class TranslationPipeline(Text2TextGenerationPipeline):
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            src_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
+                The language of the input. Might be required for multilingual models. Will not have any effect for
+                single-pair translation models.
+            tgt_lang (:obj:`str`, `optional`, defaults to :obj:`None`):
+                The language of the desired output. Might be required for multilingual models. Will not have any
+                effect for single-pair translation models.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -258,4 +312,10 @@ class TranslationPipeline(Text2TextGenerationPipeline):
             - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
               -- The token ids of the translation.
         """
-        return super().__call__(*args, **kwargs)
+        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
+
+        src_lang = src_lang if src_lang is not None else self.src_lang
+        tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang
+        with self.device_placement():
+            inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
+            return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
@@ -361,6 +361,9 @@ if is_torch_available():
 else:
     torch_device = None

+if is_tf_available():
+    import tensorflow as tf
+

 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
         raise RuntimeError(f"'{cmd_str}' produced no output.")

     return result
+
+
+def nested_simplify(obj, decimals=3):
+    """
+    Simplifies an object by rounding float numbers and downcasting tensors/numpy arrays, to allow for simple
+    equality tests within tests.
+    """
+    from transformers.tokenization_utils import BatchEncoding
+
+    if isinstance(obj, list):
+        return [nested_simplify(item, decimals) for item in obj]
+    elif isinstance(obj, (dict, BatchEncoding)):
+        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
+    elif isinstance(obj, (str, int)):
+        return obj
+    elif is_torch_available() and isinstance(obj, torch.Tensor):
+        # Propagate `decimals` through the recursion (a plain recursive call would silently reset it to 3).
+        return nested_simplify(obj.tolist(), decimals)
+    elif is_tf_available() and tf.is_tensor(obj):
+        return nested_simplify(obj.numpy().tolist(), decimals)
+    elif isinstance(obj, float):
+        return round(obj, decimals)
+    else:
+        raise Exception(f"Not supported: {type(obj)}")
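A short usage sketch of the new helper (values chosen to exercise each branch; note `round(0.99987, 3)` is exactly 1.0):

    import torch
    from transformers.testing_utils import nested_simplify

    assert nested_simplify(0.123456) == 0.123
    assert nested_simplify({"score": 0.99987}) == {"score": 1.0}
    assert nested_simplify(torch.tensor([[1, 2]])) == [[1, 2]]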
@@ -17,11 +17,15 @@ import unittest

 import pytest

 from transformers import pipeline
-from transformers.testing_utils import is_pipeline_test, require_torch, slow
+from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow

 from .test_pipelines_common import MonoInputPipelineCommonMixin

+if is_torch_available():
+    from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
+

 class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "translation_en_to_de"
     small_models = ["patrickvonplaten/t5-tiny-random"]  # Default model - Models tested without the @slow decorator
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
             pipeline(task="translation_cn_to_ar")

         # but we do for this one
-        pipeline(task="translation_en_to_de")
+        translator = pipeline(task="translation_en_to_de")
+        self.assertEqual(translator.src_lang, "en")
+        self.assertEqual(translator.tgt_lang, "de")
+
+    @require_torch
+    @slow
+    def test_multilingual_translation(self):
+        model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
+        tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
+
+        translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
+        # Missing src_lang, tgt_lang
+        with self.assertRaises(ValueError):
+            translator("This is a test")
+
+        outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
+        self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
+
+        outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
+        self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
+
+        # src_lang, tgt_lang can also be defined at pipeline instantiation time
+        translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
+        outputs = translator("This is a test")
+        self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])

     @require_torch
     def test_translation_on_odd_language(self):
         model = "patrickvonplaten/t5-tiny-random"
-        pipeline(task="translation_cn_to_ar", model=model)
+        translator = pipeline(task="translation_cn_to_ar", model=model)
+        self.assertEqual(translator.src_lang, "cn")
+        self.assertEqual(translator.tgt_lang, "ar")

     @require_torch
     def test_translation_default_language_selection(self):
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
         with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
             nlp = pipeline(task="translation", model=model)
         self.assertEqual(nlp.task, "translation_en_to_de")
+        self.assertEqual(nlp.src_lang, "en")
+        self.assertEqual(nlp.tgt_lang, "de")

     @require_torch
     def test_translation_with_no_language_no_model_fails(self):
...
@@ -20,7 +20,7 @@ from shutil import copyfile

 from transformers import M2M100Tokenizer, is_torch_available
 from transformers.file_utils import is_sentencepiece_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

 if is_sentencepiece_available():
@@ -191,3 +191,18 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
         self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
         self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
         self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
+
+    @require_torch
+    def test_tokenizer_translation(self):
+        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")
+
+        self.assertEqual(
+            nested_simplify(inputs),
+            {
+                # __en__, A, test, EOS
+                "input_ids": [[128022, 58, 4183, 2]],
+                "attention_mask": [[1, 1, 1, 1]],
+                # __ar__
+                "forced_bos_token_id": 128006,
+            },
+        )
@@ -17,7 +17,7 @@ import tempfile
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

 from .test_tokenization_common import TokenizerTesterMixin
@@ -232,3 +232,18 @@ class MBartEnroIntegrationTest(unittest.TestCase):
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+
+    @require_torch
+    def test_tokenizer_translation(self):
+        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
+
+        self.assertEqual(
+            nested_simplify(inputs),
+            {
+                # A, test, EOS, en_XX
+                "input_ids": [[62, 3034, 2, 250004]],
+                "attention_mask": [[1, 1, 1, 1]],
+                # ar_AR
+                "forced_bos_token_id": 250001,
+            },
+        )
@@ -17,7 +17,7 @@ import tempfile
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

 from .test_tokenization_common import TokenizerTesterMixin
@@ -194,3 +194,18 @@ class MBartOneToManyIntegrationTest(unittest.TestCase):
         self.assertEqual(batch.input_ids.shape[1], 3)
         self.assertEqual(batch.decoder_input_ids.shape[1], 10)
+
+    @require_torch
+    def test_tokenizer_translation(self):
+        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")
+
+        self.assertEqual(
+            nested_simplify(inputs),
+            {
+                # en_XX, A, test, EOS
+                "input_ids": [[250004, 62, 3034, 2]],
+                "attention_mask": [[1, 1, 1, 1]],
+                # ar_AR
+                "forced_bos_token_id": 250001,
+            },
+        )