Unverified commit 92970c0c authored by Nicolas Patry, committed by GitHub

Enabling multilingual models for translation pipelines. (#10536)



* [WIP] Enabling multilingual models for translation pipelines.

* decoder_input_ids -> forced_bos_token_id

* Improve docstring.

* Rebase

* Fixing 2 bugs

- Type token_ids coming from `_parse_and_tokenize`
- Wrong index from tgt_lang.

* Fixing black version.

* Adding tests for _build_translation_inputs and add them for all
tokenizers.

* Mbart actually puts the lang code at the end.

* Fixing m2m100.

* Adding TF support to `deep_round`.

* Update src/transformers/pipelines/text2text_generation.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding one line comment.

* Fixing M2M100 `_build_translation_input_ids`, and fix the call site.

* Fixing tests + deep_round -> nested_simplify
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 5254220e
@@ -288,6 +288,16 @@ class M2M100Tokenizer(PreTrainedTokenizer):
        self.set_src_lang_special_tokens(self.src_lang)
        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
        """Used by the translation pipeline to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
        tgt_lang_id = self.get_lang_id(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    @contextmanager
    def as_target_tokenizer(self):
        """
......
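For orientation, a minimal usage sketch of this helper as the pipeline is expected to drive it; the checkpoint name is an assumption for illustration, not part of the diff:

    from transformers import M2M100Tokenizer

    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")  # assumed checkpoint
    inputs = tokenizer._build_translation_inputs("Hello world", src_lang="en", tgt_lang="fr")
    # `inputs` now holds "input_ids", "attention_mask" and the extra "forced_bos_token_id";
    # the pipeline forwards all of them to model.generate(**inputs).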
@@ -186,6 +186,16 @@ class MBartTokenizer(XLMRobertaTokenizer):
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
        """Used by the translation pipeline to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
......
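The only difference from the M2M100 version above is the target-language lookup: MBart language codes (e.g. ar_AR) are themselves tokens in the vocabulary, so a plain convert_tokens_to_ids call suffices. A one-line sketch, with the id taken from the tests further below:

    tgt_lang_id = tokenizer.convert_tokens_to_ids("ar_AR")  # e.g. 250001, used as forced_bos_token_id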
@@ -278,6 +278,16 @@ class MBart50Tokenizer(PreTrainedTokenizer):
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
        """Used by the translation pipeline to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
......
@@ -241,6 +241,16 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
            special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
        )

    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
        """Used by the translation pipeline to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
......
@@ -160,6 +160,16 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs):
        """Used by the translation pipeline to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", **extra_kwargs)
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
......
@@ -616,7 +616,10 @@ class Pipeline(_ScikitCompat):
        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
-        return {name: tensor.to(self.device) for name, tensor in inputs.items()}
+        return {
+            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
+            for name, tensor in inputs.items()
+        }

    def check_model_type(self, supported_models: Union[List[str], dict]):
        """
......
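The new guard matters because `_build_translation_inputs` adds a plain int to the encoded inputs, and calling `.to(device)` on an int would fail. A sketch of the kind of dict that now flows through (values taken from the MBart test below):

    import torch

    inputs = {
        "input_ids": torch.tensor([[62, 3034, 2, 250004]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1]]),
        "forced_bos_token_id": 250001,  # plain int: must be passed through untouched
    }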
from typing import Optional
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..tokenization_utils import TruncationStrategy
from ..utils import logging
@@ -50,6 +52,28 @@ class Text2TextGenerationPipeline(Pipeline):
"""
return True
def _parse_and_tokenize(self, *args, truncation):
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(args[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
args = ([prefix + arg for arg in args[0]],)
padding = True
elif isinstance(args[0], str):
args = (prefix + args[0],)
padding = False
else:
raise ValueError(
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
)
inputs = super()._parse_and_tokenize(*args, padding=padding, truncation=truncation)
# This is produced by tokenizers but is an invalid generate kwargs
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
return inputs
def __call__(
self,
*args,
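For concreteness, a sketch of the branch logic above, assuming a T5-style config whose prefix is "translate English to German: ":

    prefix = "translate English to German: "   # assumed value of self.model.config.prefix
    args = (["Hello", "Bye"],)                 # list input
    args = ([prefix + arg for arg in args[0]],)  # -> prefixed batch, padding=True
    # a plain str input gets the same prefix, with padding=False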
@@ -88,53 +112,41 @@ class Text2TextGenerationPipeline(Pipeline):
"""
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
if isinstance(args[0], list):
assert (
self.tokenizer.pad_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id when using a batch input"
args = ([prefix + arg for arg in args[0]],)
padding = True
with self.device_placement():
inputs = self._parse_and_tokenize(*args, truncation=truncation)
return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
elif isinstance(args[0], str):
args = (prefix + args[0],)
padding = False
else:
raise ValueError(
f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
)
def _generate(
self, inputs, return_tensors: bool, return_text: bool, clean_up_tokenization_spaces: bool, generate_kwargs
):
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
with self.device_placement():
inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
input_length = inputs["input_ids"].shape[-1]
elif self.framework == "tf":
input_length = tf.shape(inputs["input_ids"])[-1].numpy()
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, min_length, max_length)
generations = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generate_kwargs,
)
results = []
for generation in generations:
record = {}
if return_tensors:
record[f"{self.return_name}_token_ids"] = generation
if return_text:
record[f"{self.return_name}_text"] = self.tokenizer.decode(
generation,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
min_length = generate_kwargs.get("min_length", self.model.config.min_length)
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, min_length, max_length)
generate_kwargs.update(inputs)
generations = self.model.generate(
**generate_kwargs,
)
results = []
for generation in generations:
record = {}
if return_tensors:
record[f"{self.return_name}_token_ids"] = generation
if return_text:
record[f"{self.return_name}_text"] = self.tokenizer.decode(
generation,
skip_special_tokens=True,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
)
results.append(record)
return results
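The key trick above is `generate_kwargs.update(inputs)`: any extra key the tokenizer produced (such as `forced_bos_token_id`) reaches `model.generate` without the pipeline having to know about it. Schematically, with illustrative values:

    generate_kwargs = {"max_length": 40}
    generate_kwargs.update({
        "input_ids": input_ids,            # tensors from the tokenizer (assumed defined)
        "attention_mask": attention_mask,
        "forced_bos_token_id": 250001,
    })
    generations = model.generate(**generate_kwargs)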
@add_end_docstrings(PIPELINE_INIT_ARGS)
@@ -226,6 +238,23 @@ class TranslationPipeline(Text2TextGenerationPipeline):
    # Used in the return key of the pipeline.
    return_name = "translation"

    src_lang: Optional[str] = None
    tgt_lang: Optional[str] = None

    def __init__(self, *args, src_lang=None, tgt_lang=None, **kwargs):
        super().__init__(*args, **kwargs)
        if src_lang is not None:
            self.src_lang = src_lang
        if tgt_lang is not None:
            self.tgt_lang = tgt_lang
        if src_lang is None and tgt_lang is None:
            # Backwards compatibility; passing src_lang/tgt_lang directly is preferred.
            task = kwargs.get("task", "")
            items = task.split("_")
            if task and len(items) == 4:
                # translation, XX, to, YY
                self.src_lang = items[1]
                self.tgt_lang = items[3]

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        if input_length > 0.9 * max_length:
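The backwards-compatible parsing relies on the task name having exactly four underscore-separated parts; for example:

    task = "translation_en_to_de"
    items = task.split("_")  # ["translation", "en", "to", "de"] -> len(items) == 4
    # src_lang = items[1] = "en"; tgt_lang = items[3] = "de"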
@@ -233,8 +262,27 @@ class TranslationPipeline(Text2TextGenerationPipeline):
f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
"increasing your max_length manually, e.g. translator('...', max_length=400)"
)
return True
def __call__(self, *args, **kwargs):
def _parse_and_tokenize(self, *args, src_lang, tgt_lang, truncation):
if getattr(self.tokenizer, "_build_translation_inputs", None):
return self.tokenizer._build_translation_inputs(
*args, src_lang=src_lang, tgt_lang=tgt_lang, truncation=truncation
)
else:
return super()._parse_and_tokenize(*args, truncation=truncation)
def __call__(
self,
*args,
return_tensors=False,
return_text=True,
clean_up_tokenization_spaces=False,
truncation=TruncationStrategy.DO_NOT_TRUNCATE,
src_lang=None,
tgt_lang=None,
**generate_kwargs
):
r"""
Translate the text(s) given as inputs.
@@ -247,6 +295,12 @@
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
            src_lang (:obj:`str`, `optional`):
                The language of the input. Might be required for multilingual models. Will not have any effect for
                single-pair translation models.
            tgt_lang (:obj:`str`, `optional`):
                The language of the desired output. Might be required for multilingual models. Will not have any
                effect for single-pair translation models.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).
@@ -258,4 +312,10 @@
            - **translation_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
              -- The token ids of the translation.
        """
-        return super().__call__(*args, **kwargs)
+        assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
+
+        src_lang = src_lang if src_lang is not None else self.src_lang
+        tgt_lang = tgt_lang if tgt_lang is not None else self.tgt_lang
+        with self.device_placement():
+            inputs = self._parse_and_tokenize(*args, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
+            return self._generate(inputs, return_tensors, return_text, clean_up_tokenization_spaces, generate_kwargs)
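Taken together, this enables multilingual checkpoints through the pipeline API; a usage sketch based on the model exercised in the tests below (output elided):

    from transformers import pipeline

    translator = pipeline("translation", model="facebook/mbart-large-50-many-to-many-mmt")
    translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
    # -> [{"translation_text": "..."}]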
@@ -361,6 +361,9 @@ if is_torch_available():
else:
    torch_device = None

if is_tf_available():
    import tensorflow as tf


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
@@ -1174,3 +1177,26 @@ def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False
        raise RuntimeError(f"'{cmd_str}' produced no output.")
    return result


def nested_simplify(obj, decimals=3):
    """
    Simplifies an object by rounding float numbers and unpacking tensors/numpy arrays into plain Python lists, so
    that test assertions can use simple equality.
    """
    from transformers.tokenization_utils import BatchEncoding

    if isinstance(obj, list):
        return [nested_simplify(item, decimals) for item in obj]
    elif isinstance(obj, (dict, BatchEncoding)):
        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
    elif isinstance(obj, (str, int)):
        return obj
    elif is_torch_available() and isinstance(obj, torch.Tensor):
        return nested_simplify(obj.tolist(), decimals)
    elif is_tf_available() and tf.is_tensor(obj):
        return nested_simplify(obj.numpy().tolist(), decimals)
    elif isinstance(obj, float):
        return round(obj, decimals)
    else:
        raise Exception(f"Not supported: {type(obj)}")
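For example, assuming torch is available:

    import torch

    nested_simplify({"scores": torch.tensor([0.123456]), "label": "ok"})
    # -> {"scores": [0.123], "label": "ok"}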
@@ -17,11 +17,15 @@ import unittest
import pytest

from transformers import pipeline
-from transformers.testing_utils import is_pipeline_test, require_torch, slow
+from transformers.testing_utils import is_pipeline_test, is_torch_available, require_torch, slow

from .test_pipelines_common import MonoInputPipelineCommonMixin

if is_torch_available():
    from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration


class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "translation_en_to_de"
    small_models = ["patrickvonplaten/t5-tiny-random"]  # Default model - Models tested without the @slow decorator
@@ -48,12 +52,38 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
pipeline(task="translation_cn_to_ar")
# but we do for this one
pipeline(task="translation_en_to_de")
translator = pipeline(task="translation_en_to_de")
self.assertEquals(translator.src_lang, "en")
self.assertEquals(translator.tgt_lang, "de")
@require_torch
@slow
def test_multilingual_translation(self):
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translator = pipeline(task="translation", model=model, tokenizer=tokenizer)
# Missing src_lang, tgt_lang
with self.assertRaises(ValueError):
translator("This is a test")
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])
# src_lang, tgt_lang can be defined at pipeline call time
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
outputs = translator("This is a test")
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])
@require_torch
def test_translation_on_odd_language(self):
model = "patrickvonplaten/t5-tiny-random"
pipeline(task="translation_cn_to_ar", model=model)
translator = pipeline(task="translation_cn_to_ar", model=model)
self.assertEquals(translator.src_lang, "cn")
self.assertEquals(translator.tgt_lang, "ar")
@require_torch
def test_translation_default_language_selection(self):
@@ -61,6 +91,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
        with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
            nlp = pipeline(task="translation", model=model)
        self.assertEqual(nlp.task, "translation_en_to_de")
        self.assertEqual(nlp.src_lang, "en")
        self.assertEqual(nlp.tgt_lang, "de")

    @require_torch
    def test_translation_with_no_language_no_model_fails(self):
......
@@ -20,7 +20,7 @@ from shutil import copyfile
from transformers import M2M100Tokenizer, is_torch_available
from transformers.file_utils import is_sentencepiece_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

if is_sentencepiece_available():
@@ -191,3 +191,18 @@
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
        self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
        self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])

    @require_torch
    def test_tokenizer_translation(self):
        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en", tgt_lang="ar")

        self.assertEqual(
            nested_simplify(inputs),
            {
                # en, A, test, EOS
                "input_ids": [[128022, 58, 4183, 2]],
                "attention_mask": [[1, 1, 1, 1]],
                # ar
                "forced_bos_token_id": 128006,
            },
        )
@@ -17,7 +17,7 @@ import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

from .test_tokenization_common import TokenizerTesterMixin
@@ -232,3 +232,18 @@
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)

    @require_torch
    def test_tokenizer_translation(self):
        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")

        self.assertEqual(
            nested_simplify(inputs),
            {
                # A, test, EOS, en_XX
                "input_ids": [[62, 3034, 2, 250004]],
                "attention_mask": [[1, 1, 1, 1]],
                # ar_AR
                "forced_bos_token_id": 250001,
            },
        )
@@ -17,7 +17,7 @@ import tempfile
import unittest

from transformers import SPIECE_UNDERLINE, BatchEncoding, MBart50Tokenizer, MBart50TokenizerFast, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
+from transformers.testing_utils import nested_simplify, require_sentencepiece, require_tokenizers, require_torch

from .test_tokenization_common import TokenizerTesterMixin
@@ -194,3 +194,18 @@
        self.assertEqual(batch.input_ids.shape[1], 3)
        self.assertEqual(batch.decoder_input_ids.shape[1], 10)

    @require_torch
    def test_tokenizer_translation(self):
        inputs = self.tokenizer._build_translation_inputs("A test", src_lang="en_XX", tgt_lang="ar_AR")

        self.assertEqual(
            nested_simplify(inputs),
            {
                # en_XX, A, test, EOS
                "input_ids": [[250004, 62, 3034, 2]],
                "attention_mask": [[1, 1, 1, 1]],
                # ar_AR
                "forced_bos_token_id": 250001,
            },
        )
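As the three tokenizer tests above show, and as the commit message notes ("Mbart actually puts the lang code at the end"), the special-token layout differs per tokenizer; schematically:

    # M2M100:  [src_lang_code, tokens..., EOS]   -> [128022, 58, 4183, 2]
    # MBart:   [tokens..., EOS, src_lang_code]   -> [62, 3034, 2, 250004]
    # MBart50: [src_lang_code, tokens..., EOS]   -> [250004, 62, 3034, 2]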