Unverified Commit 9a687ebb authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[Marian Fixes] prevent predicting pad_token_id before softmax, support...

[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
parent 839bfaed
MarianMTModel
MarianMT
----------------------------------------------------
**DISCLAIMER:** If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer
These models are for machine translation. The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
Opus Project
~~~~~~~~~~~~
The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card.
Implementation Notes
~~~~~~~~~~~~~~~~~~~~
- each model is about 298 MB on disk, there are 1,000+ models.
- Models are named with the following patter 'Helsinki-NLP/opus-mt-{src_langs}-{targ_langs}'. If there are multiple source or target languages they are joined by a '+' symbol.
- The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
- The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
- the 80 opus models that require BPE preprocessing are not supported.
- There is an outstanding issue w.r.t multilingual models and language codes.
- The modeling code is the same as ``BartModel`` with a few minor modifications:
- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:
- static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``)
- a new final_logits_bias (``MarianConfig.add_bias_logits=True``)
- no layernorm_embedding (``MarianConfig.normalize_embedding=False``)
- the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix. (Bart uses <s/>)
- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``
Naming
~~~~~~
- All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``
- The language codes used to name models are inconsistent. Two digit codes can usually be found `here <https://developers.google.com/admin-sdk/directory/v1/languages>`_, three digit codes require googling "language code {code}".
- Codes formatted like ``es_AR`` are usually ``code_{region}``. That one is spanish documents from Argentina.
Multilingual Models
~~~~~~~~~~~~~~~~~~~~
All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``:
- if ``src`` is in all caps, the model supports multiple input languages, you can figure out which ones by looking at the model card, or the Group Members `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_ .
- if ``tgt`` is in all caps, the model can output multiple languages, and you should specify a language code by prepending the desired output language to the src_text
- You can see a tokenizer's supported language codes in ``tokenizer.supported_language_codes``
Example of translating english to many romance languages, using language codes:
.. code-block:: python
from transformers import MarianMTModel, MarianTokenizer
src_text = [
'>>fr<< this is a sentence in english that we want to translate to french',
'>>pt<< This should go to portuguese',
'>>es<< And this to Spanish'
]
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
tokenizer = MarianTokenizer.from_pretrained(model_name)
print(tokenizer.supported_language_codes)
model = MarianMTModel.from_pretrained(model_name)
translated = model.generate(**tokenizer.prepare_translation_batch(src_text))
tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
# ["c'est une phrase en anglais que nous voulons traduire en français",
# 'Isto deve ir para o português.',
# 'Y esto al español']
Sometimes, models were trained on collections of languages that do not resolve to a group. In this case, _ is used as a separator for src or tgt, as in ``'Helsinki-NLP/opus-mt-en_el_es_fi-en_el_es_fi'``. These still require language codes.
There are many supported regional language codes, like ``>>es_ES<<`` (Spain) and ``>>es_AR<<`` (Argentina), that do not seem to change translations. I have not found these to provide different results than just using ``>>es<<``.
For Example:
- ``Helsinki-NLP/opus-mt-NORTH_EU-NORTH_EU``: translates from all NORTH_EU languages (see `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_) to all NORTH_EU languages. Use a special language code like ``>>de<<`` to specify output language.
- ``Helsinki-NLP/opus-mt-ROMANCE-en``: translates from many romance languages to english, no codes needed since there is only 1 tgt language.
.. code-block:: python
GROUP_MEMBERS = {
'ZH': ['cmn', 'cn', 'yue', 'ze_zh', 'zh_cn', 'zh_CN', 'zh_HK', 'zh_tw', 'zh_TW', 'zh_yue', 'zhs', 'zht', 'zh'],
'ROMANCE': ['fr', 'fr_BE', 'fr_CA', 'fr_FR', 'wa', 'frp', 'oc', 'ca', 'rm', 'lld', 'fur', 'lij', 'lmo', 'es', 'es_AR', 'es_CL', 'es_CO', 'es_CR', 'es_DO', 'es_EC', 'es_ES', 'es_GT', 'es_HN', 'es_MX', 'es_NI', 'es_PA', 'es_PE', 'es_PR', 'es_SV', 'es_UY', 'es_VE', 'pt', 'pt_br', 'pt_BR', 'pt_PT', 'gl', 'lad', 'an', 'mwl', 'it', 'it_IT', 'co', 'nap', 'scn', 'vec', 'sc', 'ro', 'la'],
'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
'SAMI': ['se', 'sma', 'smj', 'smn', 'sms'],
'NORWAY': ['nb_NO', 'nb', 'nn_NO', 'nn', 'nog', 'no_nb', 'no'],
'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv']
}
Code to see available pretrained models:
.. code-block:: python
from transformers.hf_api import HfApi
model_list = HfApi().model_list()
org = "Helsinki-NLP"
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
suffix = [x.split('/')[1] for x in model_ids]
multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]
MarianMTModel
~~~~~~~~~~~~~
......
......@@ -95,6 +95,97 @@ def find_model_file(dest_dir): # this one better
return model_file
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
GROUPS = [
("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
(ROM_GROUP, "ROMANCE"),
("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
("se+sma+smj+smn+sms", "SAMI"),
("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
("ga+cy+br+gd+kw+gv", "CELTIC"), # https://en.wikipedia.org/wiki/Insular_Celtic_languages
]
GROUP_TO_OPUS_NAME = {
"opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
"opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
"opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
"opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
"opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
"opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
"opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
"opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
"opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
"opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
"opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
"opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
"+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
"+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
"opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
"opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
"opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
}
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
ORG_NAME = "Helsinki-NLP/"
def convert_opus_name_to_hf_name(x):
for substr, grp_name in GROUPS:
x = x.replace(substr, grp_name)
return x.replace("+", "_")
def convert_hf_name_to_opus_name(hf_model_name):
"""Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
if hf_model_name in GROUP_TO_OPUS_NAME:
opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
else:
opus_w_prefix = hf_model_name.replace("_", "+")
return remove_prefix(opus_w_prefix, "opus-mt-")
def write_model_card(
hf_model_name: str,
repo_path="OPUS-MT-train/models/",
dry_run=False,
model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
) -> str:
"""Copy the most recent model's readme section from opus, and add metadata.
upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
"""
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
s, t = ",".join(opus_src), ",".join(opus_tgt)
extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
# combine with opus markdown
opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
assert opus_readme_path.exists(), opus_readme_path
content = opus_readme_path.open().read()
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
content = "*".join(content.split("*")[1:])
content = extra_markdown + "\n* " + content.replace("download", "download original weights")
if dry_run:
return content
# Save string to model_cards/hf_model_name/readme.md
model_card_dir.mkdir(exist_ok=True)
sub_dir = model_card_dir / hf_model_name
sub_dir.mkdir(exist_ok=True)
dest = sub_dir / "README.md"
dest.open("w").write(content)
return content
def get_clean_model_id_mapping(multiling_model_ids):
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
def make_registry(repo_path="Opus-MT-train/models"):
if not (Path(repo_path) / "fr-en" / "README.md").exists():
raise ValueError(
......@@ -109,10 +200,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
else:
lns = list(open(p / "README.md").readlines())
results[p.name] = _parse_readme(lns)
return [(k, v["pre-processing"], v["download"]) for k, v in results.items()]
CH_GROUP = "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
......@@ -122,12 +210,12 @@ def convert_all_sentencepiece_models(model_list=None, repo_path=None):
dest_dir.mkdir(exist_ok=True)
if model_list is None:
model_list: list = make_registry(repo_path=repo_path)
for k, prepro, download in tqdm(model_list):
for k, prepro, download, test_set_url in tqdm(model_list):
if "SentencePiece" not in prepro: # dont convert BPE models.
continue
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
download_and_unzip(download, save_dir / k)
pair_name = k.replace(CH_GROUP, "ch_group")
pair_name = convert_opus_name_to_hf_name(k)
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
......@@ -135,12 +223,10 @@ def lmap(f, x) -> List:
return list(map(f, x))
def fetch_test_set(readmes_raw, pair):
def fetch_test_set(test_set_url):
import wget
download_url = readmes_raw[pair]["download"]
test_set_url = download_url[:-4] + ".test.txt"
fname = wget.download(test_set_url, f"opus_test_{pair}.txt")
fname = wget.download(test_set_url, f"opus_test.txt")
lns = Path(fname).open().readlines()
src = lmap(str.strip, lns[::4])
gold = lmap(str.strip, lns[1::4])
......
......@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
def prepare_logits_for_generation(self, logits, cur_len, max_length):
if cur_len == 1:
self._force_token_ids_generation(scores, self.config.bos_token_id)
self._force_token_ids_generation(logits, self.config.bos_token_id)
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
def _force_token_ids_generation(self, scores, token_ids) -> None:
"""force one of token_ids to be generated by setting prob of all other tokens to 0"""
......
......@@ -31,7 +31,7 @@ class MarianMTModel(BartForConditionalGeneration):
src = 'fr' # source language
trg = 'en' # target language
sample_text = "où est l'arrêt de bus ?"
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}' # `Model List`__
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
model = MarianMTModel.from_pretrained(mname)
tok = MarianTokenizer.from_pretrained(mname)
......@@ -43,7 +43,8 @@ class MarianMTModel(BartForConditionalGeneration):
pretrained_model_archive_map = {} # see https://huggingface.co/models?search=Helsinki-NLP
def prepare_scores_for_generation(self, scores, cur_len, max_length):
def prepare_logits_for_generation(self, logits, cur_len, max_length):
logits[:, self.config.pad_token_id] = float("-inf")
if cur_len == max_length - 1 and self.config.eos_token_id is not None:
self._force_token_ids_generation(scores, self.config.eos_token_id)
return scores
self._force_token_ids_generation(logits, self.config.eos_token_id)
return logits
......@@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}
def prepare_scores_for_generation(self, scores, **kwargs):
return scores
def prepare_logits_for_generation(self, logits, **kwargs):
return logits
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
......@@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solutino
scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.prepare_logits_for_generation(
next_token_logits, cur_len=cur_len, max_length=max_length
)
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
# set eos token prob to zero if min_length is not reached
if eos_token_id is not None and cur_len < min_length:
......
import json
import re
import warnings
from typing import Dict, List, Optional, Union
......@@ -14,7 +15,7 @@ vocab_files_names = {
"vocab": "vocab.json",
"tokenizer_config_file": "tokenizer_config.json",
}
MODEL_NAMES = ("opus-mt-en-de",)
MODEL_NAMES = ("opus-mt-en-de",) # TODO(SS): the only required constant is vocab_files_names
PRETRAINED_VOCAB_FILES_MAP = {
k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
for k, fname in vocab_files_names.items()
......@@ -41,6 +42,7 @@ class MarianTokenizer(PreTrainedTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
model_input_names = ["attention_mask"] # actually attention_mask, decoder_attention_mask
language_code_re = re.compile(">>.+<<") # type: re.Pattern
def __init__(
self,
......@@ -72,8 +74,6 @@ class MarianTokenizer(PreTrainedTokenizer):
self.target_lang = target_lang
# load SentencePiece model for pre-processing
self.paths = {}
self.spm_source = sentencepiece.SentencePieceProcessor()
self.spm_source.Load(source_spm)
......@@ -82,9 +82,7 @@ class MarianTokenizer(PreTrainedTokenizer):
# Multilingual target side: default to using first supported language code.
self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
self.tgt_lang_id = None # will not be used unless it is set through prepare_translation_batch
# Note(SS): sentence_splitter would require lots of book-keeping.
try:
from mosestokenizer import MosesPunctuationNormalizer
......@@ -93,11 +91,23 @@ class MarianTokenizer(PreTrainedTokenizer):
warnings.warn("Recommended: pip install mosestokenizer")
self.punc_normalizer = lambda x: x
def normalize(self, x: str) -> str:
"""Cover moses empty string edge case. They return empty list for '' input!"""
return self.punc_normalizer(x) if x else ""
def _convert_token_to_id(self, token):
return self.encoder.get(token, self.encoder[self.unk_token])
def remove_language_code(self, text: str):
"""Remove language codes like <<fr>> before sentencepiece"""
match = self.language_code_re.match(text)
code: list = [match.group(0)] if match else []
return code, self.language_code_re.sub("", text)
def _tokenize(self, text: str) -> List[str]:
return self.current_spm.EncodeAsPieces(text)
code, text = self.remove_language_code(text)
pieces = self.current_spm.EncodeAsPieces(text)
return code + pieces
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the encoder."""
......@@ -125,7 +135,7 @@ class MarianTokenizer(PreTrainedTokenizer):
pad_to_max_length: bool = True,
return_tensors: str = "pt",
) -> BatchEncoding:
"""
"""Prepare model inputs for translation. For best performance, translate one sentence at a time.
Arguments:
src_texts: list of src language texts
tgt_texts: list of tgt language texts
......@@ -138,7 +148,10 @@ class MarianTokenizer(PreTrainedTokenizer):
all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
If no tgt_text is specified, the only keys will be input_ids and attention_mask.
"""
if "" in src_texts:
raise ValueError(f"found empty string in src_texts: {src_texts}")
self.current_spm = self.spm_source
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
model_inputs: BatchEncoding = self.batch_encode_plus(
src_texts,
add_special_tokens=True,
......
......@@ -33,15 +33,21 @@ if is_torch_available():
MarianTokenizer,
MarianMTModel,
)
from transformers.convert_marian_to_pytorch import (
convert_hf_name_to_opus_name,
convert_opus_name_to_hf_name,
ORG_NAME,
)
class ModelManagementTests(unittest.TestCase):
@slow
def test_model_count(self):
def test_model_names(self):
model_list = HfApi().model_list()
expected_num_models = 1011
actual_num_models = len([x for x in model_list if x.modelId.startswith("Helsinki-NLP")])
self.assertEqual(expected_num_models, actual_num_models)
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
self.assertListEqual([], bad_model_ids)
self.assertGreater(len(model_ids), 500)
@require_torch
......@@ -91,12 +97,12 @@ class MarianIntegrationTest(unittest.TestCase):
self.assertListEqual(self.expected_text, generated_words)
def translate_src_text(self, **tokenizer_kwargs):
model_inputs: dict = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
torch_device
)
self.assertEqual(self.model.device, model_inputs["input_ids"].device)
self.assertEqual(self.model.device, model_inputs.input_ids.device)
generated_ids = self.model.generate(
model_inputs["input_ids"], attention_mask=model_inputs["attention_mask"], num_beams=2
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
)
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return generated_words
......@@ -106,10 +112,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
@slow
def test_forward(self):
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
expected = [38, 121, 14, 697, 38848, 0]
expected_ids = [38, 121, 14, 697, 38848, 0]
model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
self.assertListEqual(expected, model_inputs["input_ids"][0].tolist())
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
desired_keys = {
"input_ids",
......@@ -125,20 +131,19 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
def test_tokenizer_equivalence(self):
batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
input_ids = batch["input_ids"][0]
expected = [38, 121, 14, 697, 38848, 0]
self.assertListEqual(expected, input_ids.tolist())
self.assertListEqual(expected, batch.input_ids[0].tolist())
def test_unk_support(self):
t = self.tokenizer
ids = t.prepare_translation_batch(["||"]).to(torch_device)["input_ids"][0].tolist()
ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
self.assertEqual(expected, ids)
def test_pad_not_split(self):
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"])["input_ids"][0]
input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0] # pad
self.assertListEqual(expected_w_pad, input_ids_w_pad.tolist())
self.assertListEqual(expected_w_pad, input_ids_w_pad)
@slow
def test_batch_generation_en_de(self):
......@@ -187,9 +192,8 @@ class TestMarian_RU_FR(MarianIntegrationTest):
src = "ru"
tgt = "fr"
src_text = ["Он показал мне рукопись своей новой пьесы."]
expected_text = ["Il me montre un manuscrit de sa nouvelle pièce."]
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
@slow
def test_batch_generation_ru_fr(self):
self._assert_generated_batch_equal_expected()
......@@ -197,36 +201,59 @@ class TestMarian_RU_FR(MarianIntegrationTest):
class TestMarian_MT_EN(MarianIntegrationTest):
src = "mt"
tgt = "en"
src_text = ["Il - Babiloniżi b'mod żbaljat ikkonkludew li l - Alla l - veru kien dgħajjef."]
expected_text = ["The Babylonians wrongly concluded that the true God was weak."]
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
@unittest.skip("") # Known Issue: This model generates a string of .... at the end of the translation.
def test_batch_generation_mt_en(self):
self._assert_generated_batch_equal_expected()
class TestMarian_DE_Multi(MarianIntegrationTest):
src = "de"
tgt = "ch_group"
src_text = ["Er aber sprach: Das ist die Gottlosigkeit."]
class TestMarian_en_ROMANCE(MarianIntegrationTest):
"""Multilingual on target side."""
@slow
def test_translation_de_multi_does_not_error(self):
self.translate_src_text()
src = "en"
tgt = "ROMANCE"
src_text = [
">>fr<< Don't spend so much time watching TV.",
">>pt<< Your message has been sent.",
">>es<< He's two years older than me.",
]
expected_text = [
"Ne passez pas autant de temps à regarder la télé.",
"A sua mensagem foi enviada.",
"Es dos años más viejo que yo.",
]
@unittest.skip("") # "Language codes are not yet supported."
def test_batch_generation_de_multi_tgt(self):
@slow
def test_batch_generation_en_ROMANCE_multi(self):
self._assert_generated_batch_equal_expected()
@unittest.skip("") # "Language codes are not yet supported."
def test_lang_code(self):
t = "Er aber sprach"
zh_code = self.code
tok_fn = self.tokenizer.prepare_translation_batch
pass_code = tok_fn(src_texts=[t], tgt_lang_code=zh_code)["input_ids"][0]
preprocess_with_code = tok_fn(src_texts=[zh_code + " " + t])["input_ids"][0]
self.assertListEqual(pass_code.tolist(), preprocess_with_code.tolist())
for code in self.tokenizer.supported_language_codes:
self.assertIn(code, self.tokenizer.encoder)
pass_only_code = tok_fn(src_texts=[""], tgt_lang_code=zh_code)["input_ids"][0].tolist()
self.assertListEqual(pass_only_code, [self.tokenizer.encoder[zh_code], self.tokenizer.eos_token_id])
def test_tokenizer_handles_empty(self):
normalized = self.tokenizer.normalize("")
self.assertIsInstance(normalized, str)
with self.assertRaises(ValueError):
self.tokenizer.prepare_translation_batch([""])
@require_torch
class TestConversionUtils(unittest.TestCase):
def test_renaming_multilingual(self):
old_names = [
"opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
"opus-mt-cmn+cn-fi", # no group
"opus-mt-en-de", # standard name
"opus-mt-en-de", # standard name
]
expected = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
self.assertListEqual(expected, [convert_opus_name_to_hf_name(x) for x in old_names])
def test_undoing_renaming(self):
hf_names = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
converted_opus_names = [convert_hf_name_to_opus_name(x) for x in hf_names]
expected_opus_names = [
"cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
"cmn+cn-fi",
"en-de", # standard name
"en-de",
]
self.assertListEqual(expected_opus_names, converted_opus_names)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment