Unverified Commit 9a687ebb authored by Sam Shleifer, committed by GitHub


[Marian Fixes] prevent predicting pad_token_id before softmax, support language codes, name multilingual models (#4290)
parent 839bfaed
MarianMT
----------------------------------------------------

**DISCLAIMER:** If you see something strange,
file a `Github Issue <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__ and assign
@sshleifer. Translations should be similar, but not identical to, output in the test set linked to in each model card.
Implementation Notes
~~~~~~~~~~~~~~~~~~~~

- each model is about 298 MB on disk, there are 1,000+ models.
- The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
- The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
- All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
- the 80 opus models that require BPE preprocessing are not supported.
- The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:

  - static (sinusoid) positional embeddings (``MarianConfig.static_position_embeddings=True``)
  - a new final_logits_bias (``MarianConfig.add_bias_logits=True``)
  - no layernorm_embedding (``MarianConfig.normalize_embedding=False``)
  - the model starts generating with ``pad_token_id`` (which has 0 as its token_embedding) as the prefix (Bart uses ``<s/>``)

- Code to bulk convert models can be found in ``convert_marian_to_pytorch.py``
Naming
~~~~~~
- All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``
- The language codes used to name models are inconsistent. Two-digit codes can usually be found `here <https://developers.google.com/admin-sdk/directory/v1/languages>`_; three-digit codes require googling "language code {code}".
- Codes formatted like ``es_AR`` are usually ``code_{region}``; that one denotes Spanish documents from Argentina.
Multilingual Models
~~~~~~~~~~~~~~~~~~~~
All model names use the following format: ``Helsinki-NLP/opus-mt-{src}-{tgt}``:
- if ``src`` is in all caps, the model supports multiple input languages; you can figure out which ones by looking at the model card, or the Group Members `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_.
- if ``tgt`` is in all caps, the model can output multiple languages, and you should specify an output language by prepending its language code to the src_text.
- You can see a tokenizer's supported language codes in ``tokenizer.supported_language_codes``.
Example of translating English to many Romance languages, using language codes:
.. code-block:: python

    from transformers import MarianMTModel, MarianTokenizer

    src_text = [
        '>>fr<< this is a sentence in english that we want to translate to french',
        '>>pt<< This should go to portuguese',
        '>>es<< And this to Spanish'
    ]

    model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    print(tokenizer.supported_language_codes)

    model = MarianMTModel.from_pretrained(model_name)
    translated = model.generate(**tokenizer.prepare_translation_batch(src_text))
    tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
    # ["c'est une phrase en anglais que nous voulons traduire en français",
    #  'Isto deve ir para o português.',
    #  'Y esto al español']
Sometimes, models were trained on collections of languages that do not resolve to a group. In this case, ``_`` is used as a separator for src or tgt, as in ``'Helsinki-NLP/opus-mt-en_el_es_fi-en_el_es_fi'``. These still require language codes.

There are many supported regional language codes, like ``>>es_ES<<`` (Spain) and ``>>es_AR<<`` (Argentina); I have not found these to produce different translations than just using ``>>es<<``.

For example:

- ``Helsinki-NLP/opus-mt-NORTH_EU-NORTH_EU``: translates from all NORTH_EU languages (see `mapping <https://gist.github.com/sshleifer/6d20e7761931b08e73c3219027b97b8a>`_) to all NORTH_EU languages. Use a special language code like ``>>de<<`` to specify the output language.
- ``Helsinki-NLP/opus-mt-ROMANCE-en``: translates from many Romance languages to English; no codes are needed since there is only one target language.
.. code-block:: python

    GROUP_MEMBERS = {
        'ZH': ['cmn', 'cn', 'yue', 'ze_zh', 'zh_cn', 'zh_CN', 'zh_HK', 'zh_tw', 'zh_TW', 'zh_yue', 'zhs', 'zht', 'zh'],
        'ROMANCE': ['fr', 'fr_BE', 'fr_CA', 'fr_FR', 'wa', 'frp', 'oc', 'ca', 'rm', 'lld', 'fur', 'lij', 'lmo', 'es', 'es_AR', 'es_CL', 'es_CO', 'es_CR', 'es_DO', 'es_EC', 'es_ES', 'es_GT', 'es_HN', 'es_MX', 'es_NI', 'es_PA', 'es_PE', 'es_PR', 'es_SV', 'es_UY', 'es_VE', 'pt', 'pt_br', 'pt_BR', 'pt_PT', 'gl', 'lad', 'an', 'mwl', 'it', 'it_IT', 'co', 'nap', 'scn', 'vec', 'sc', 'ro', 'la'],
        'NORTH_EU': ['de', 'nl', 'fy', 'af', 'da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
        'SCANDINAVIA': ['da', 'fo', 'is', 'no', 'nb', 'nn', 'sv'],
        'SAMI': ['se', 'sma', 'smj', 'smn', 'sms'],
        'NORWAY': ['nb_NO', 'nb', 'nn_NO', 'nn', 'nog', 'no_nb', 'no'],
        'CELTIC': ['ga', 'cy', 'br', 'gd', 'kw', 'gv']
    }
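For instance, ``'es_AR' in GROUP_MEMBERS['ROMANCE']`` is ``True``, so ``opus-mt-en-ROMANCE`` can target Argentinian Spanish via the ``>>es_AR<<`` code.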
Code to see available pretrained models:
.. code-block:: python

    from transformers.hf_api import HfApi

    model_list = HfApi().model_list()
    org = "Helsinki-NLP"
    model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
    suffix = [x.split('/')[1] for x in model_ids]
    multi_models = [f'{org}/{s}' for s in suffix if s != s.lower()]
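``multi_models`` should then contain the multilingual ids, since only group names like ``ROMANCE`` add upper-case characters to a model name (e.g. ``Helsinki-NLP/opus-mt-en-ROMANCE``).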
MarianMTModel
~~~~~~~~~~~~~
@@ -95,6 +95,97 @@ def find_model_file(dest_dir):  # this one better
    return model_file
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
GROUPS = [
    ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
    (ROM_GROUP, "ROMANCE"),
    ("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
    ("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
    ("se+sma+smj+smn+sms", "SAMI"),
    ("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
    ("ga+cy+br+gd+kw+gv", "CELTIC"),  # https://en.wikipedia.org/wiki/Insular_Celtic_languages
]
GROUP_TO_OPUS_NAME = {
    "opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
    "opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
    "opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
    "opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
    "opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
    "opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
    "opus-mt-en-ROMANCE": "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
    "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
    "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la",
    "opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
    "opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
    "opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-ROMANCE-en": "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
    "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
    "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en",
    "opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
    "opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
}
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
ORG_NAME = "Helsinki-NLP/"
def convert_opus_name_to_hf_name(x):
    for substr, grp_name in GROUPS:
        x = x.replace(substr, grp_name)
    return x.replace("+", "_")
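# Illustrative examples (mirroring test_renaming_multilingual below); the group
# substitution runs before '+' is replaced with '_':
#   convert_opus_name_to_hf_name("opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi")
#       -> "opus-mt-ZH-fi"
#   convert_opus_name_to_hf_name("opus-mt-cmn+cn-fi")  # no full group match -> "opus-mt-cmn_cn-fi"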
def convert_hf_name_to_opus_name(hf_model_name):
    """Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
    hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
    if hf_model_name in GROUP_TO_OPUS_NAME:
        opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
    else:
        opus_w_prefix = hf_model_name.replace("_", "+")
    return remove_prefix(opus_w_prefix, "opus-mt-")
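# Illustrative round trip (mirroring test_undoing_renaming below):
#   convert_hf_name_to_opus_name("opus-mt-ZH-fi")      # -> the full "cmn+cn+...+zh-fi" entry from GROUP_TO_OPUS_NAME
#   convert_hf_name_to_opus_name("opus-mt-cmn_cn-fi")  # -> "cmn+cn-fi" (plain '_' -> '+' replacement)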
def write_model_card(
    hf_model_name: str,
    repo_path="OPUS-MT-train/models/",
    dry_run=False,
    model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
) -> str:
    """Copy the most recent model's readme section from opus, and add metadata.

    upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
    """
    hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
    opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
    opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
    readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
    s, t = ",".join(opus_src), ",".join(opus_tgt)
    extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"

    # combine with opus markdown
    opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
    assert opus_readme_path.exists(), opus_readme_path
    content = opus_readme_path.open().read()
    content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.
    content = "*".join(content.split("*")[1:])
    content = extra_markdown + "\n* " + content.replace("download", "download original weights")
    if dry_run:
        return content

    # Save string to model_cards/hf_model_name/readme.md
    model_card_dir.mkdir(exist_ok=True)
    sub_dir = model_card_dir / hf_model_name
    sub_dir.mkdir(exist_ok=True)
    dest = sub_dir / "README.md"
    dest.open("w").write(content)
    return content
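# Hypothetical dry-run usage (paths are assumptions; requires a local OPUS-MT-train checkout at repo_path):
#   card_text = write_model_card("Helsinki-NLP/opus-mt-en-de", dry_run=True)
#   print(card_text[:200])  # metadata header followed by the most recent OPUS readme section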
def get_clean_model_id_mapping(multiling_model_ids):
    return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
def make_registry(repo_path="Opus-MT-train/models"):
    if not (Path(repo_path) / "fr-en" / "README.md").exists():
        raise ValueError(
@@ -109,10 +200,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
        else:
            lns = list(open(p / "README.md").readlines())
            results[p.name] = _parse_readme(lns)
    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
def convert_all_sentencepiece_models(model_list=None, repo_path=None):
@@ -122,12 +210,12 @@ def convert_all_sentencepiece_models(model_list=None, repo_path=None):
    dest_dir.mkdir(exist_ok=True)
    if model_list is None:
        model_list: list = make_registry(repo_path=repo_path)
    for k, prepro, download, test_set_url in tqdm(model_list):
        if "SentencePiece" not in prepro:  # don't convert BPE models.
            continue
        if not os.path.exists(save_dir / k / "pytorch_model.bin"):
            download_and_unzip(download, save_dir / k)
        pair_name = convert_opus_name_to_hf_name(k)
        convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
@@ -135,12 +223,10 @@ def lmap(f, x) -> List:
    return list(map(f, x))


def fetch_test_set(test_set_url):
    import wget

    fname = wget.download(test_set_url, "opus_test.txt")
    lns = Path(fname).open().readlines()
    src = lmap(str.strip, lns[::4])
    gold = lmap(str.strip, lns[1::4])
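    # The test file interleaves records four lines apart: lns[::4] are the source
    # sentences and lns[1::4] their reference translations.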
@@ -980,12 +980,12 @@ class BartForConditionalGeneration(PretrainedBartModel):
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_logits_for_generation(self, logits, cur_len, max_length):
        if cur_len == 1:
            self._force_token_ids_generation(logits, self.config.bos_token_id)
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits

    def _force_token_ids_generation(self, scores, token_ids) -> None:
        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
@@ -31,7 +31,7 @@ class MarianMTModel(BartForConditionalGeneration):
        src = 'fr'  # source language
        trg = 'en'  # target language
        sample_text = "où est l'arrêt de bus ?"
        mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

        model = MarianMTModel.from_pretrained(mname)
        tok = MarianTokenizer.from_pretrained(mname)
@@ -43,7 +43,8 @@ class MarianMTModel(BartForConditionalGeneration):
    pretrained_model_archive_map = {}  # see https://huggingface.co/models?search=Helsinki-NLP

    def prepare_logits_for_generation(self, logits, cur_len, max_length):
        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad_token_id
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
            self._force_token_ids_generation(logits, self.config.eos_token_id)
        return logits
@@ -744,8 +744,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def prepare_logits_for_generation(self, logits, **kwargs):
        return logits

    def _use_cache(self, outputs, use_cache):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
@@ -857,7 +857,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id=None: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
@@ -1342,10 +1342,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature

            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                next_token_logits = self.prepare_logits_for_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
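Why the reordering matters (a minimal sketch, not part of the diff): setting a logit to -inf before ``log_softmax`` renormalizes the remaining probabilities, so ``pad_token_id`` can never be selected, whereas masking after normalization would leave an unnormalized distribution.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5]])  # pretend index 0 is pad_token_id
logits[:, 0] = float("-inf")              # mask before normalization, as MarianMTModel now does
scores = F.log_softmax(logits, dim=-1)
assert torch.isinf(scores[0, 0]).item()                       # pad gets zero probability
assert torch.allclose(scores.exp().sum(), torch.tensor(1.0))  # the rest still sums to 1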
import json
import re
import warnings
from typing import Dict, List, Optional, Union
@@ -14,7 +15,7 @@ vocab_files_names = {
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}
MODEL_NAMES = ("opus-mt-en-de",)  # TODO(SS): the only required constant is vocab_files_names
PRETRAINED_VOCAB_FILES_MAP = {
    k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
    for k, fname in vocab_files_names.items()
@@ -41,6 +42,7 @@ class MarianTokenizer(PreTrainedTokenizer):
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = {m: 512 for m in MODEL_NAMES}
    model_input_names = ["attention_mask"]  # actually attention_mask, decoder_attention_mask
    language_code_re = re.compile(">>.+<<")  # type: re.Pattern
    def __init__(
        self,
@@ -72,8 +74,6 @@ class MarianTokenizer(PreTrainedTokenizer):
        self.target_lang = target_lang

        # load SentencePiece model for pre-processing
        self.spm_source = sentencepiece.SentencePieceProcessor()
        self.spm_source.Load(source_spm)
@@ -82,9 +82,7 @@ class MarianTokenizer(PreTrainedTokenizer):
        # Multilingual target side: default to using first supported language code.
        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]

        try:
            from mosestokenizer import MosesPunctuationNormalizer
@@ -93,11 +91,23 @@ class MarianTokenizer(PreTrainedTokenizer):
            warnings.warn("Recommended: pip install mosestokenizer")
            self.punc_normalizer = lambda x: x
    def normalize(self, x: str) -> str:
        """Cover moses empty string edge case. They return empty list for '' input!"""
        return self.punc_normalizer(x) if x else ""

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder[self.unk_token])

    def remove_language_code(self, text: str):
        """Remove language codes like >>fr<< before sentencepiece."""
        match = self.language_code_re.match(text)
        code: list = [match.group(0)] if match else []
        return code, self.language_code_re.sub("", text)

    def _tokenize(self, text: str) -> List[str]:
        code, text = self.remove_language_code(text)
        pieces = self.current_spm.EncodeAsPieces(text)
        return code + pieces
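    # Illustrative (for a multilingual tokenizer; `tok` is a hypothetical instance):
    #   tok.remove_language_code(">>fr<< how are you?")  ->  ([">>fr<<"], " how are you?")
    # The code survives as a single vocabulary token instead of being split by sentencepiece.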
    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the encoder."""
@@ -125,7 +135,7 @@ class MarianTokenizer(PreTrainedTokenizer):
        pad_to_max_length: bool = True,
        return_tensors: str = "pt",
    ) -> BatchEncoding:
        """Prepare model inputs for translation. For best performance, translate one sentence at a time.

        Arguments:
            src_texts: list of src language texts
            tgt_texts: list of tgt language texts
@@ -138,7 +148,10 @@ class MarianTokenizer(PreTrainedTokenizer):
            all shaped bs, seq_len. (BatchEncoding is a dict of string -> tensor or lists).
            If no tgt_text is specified, the only keys will be input_ids and attention_mask.
        """
        if "" in src_texts:
            raise ValueError(f"found empty string in src_texts: {src_texts}")
        self.current_spm = self.spm_source
        src_texts = [self.normalize(t) for t in src_texts]  # this does not appear to do much
        model_inputs: BatchEncoding = self.batch_encode_plus(
            src_texts,
            add_special_tokens=True,
@@ -33,15 +33,21 @@ if is_torch_available():
        MarianTokenizer,
        MarianMTModel,
    )
    from transformers.convert_marian_to_pytorch import (
        convert_hf_name_to_opus_name,
        convert_opus_name_to_hf_name,
        ORG_NAME,
    )
class ModelManagementTests(unittest.TestCase):
    @slow
    def test_model_names(self):
        model_list = HfApi().model_list()
        model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
        bad_model_ids = [mid for mid in model_ids if "+" in mid]  # "+" should have been renamed to "_"
        self.assertListEqual([], bad_model_ids)
        self.assertGreater(len(model_ids), 500)
@require_torch
@@ -91,12 +97,12 @@ class MarianIntegrationTest(unittest.TestCase):
        self.assertListEqual(self.expected_text, generated_words)

    def translate_src_text(self, **tokenizer_kwargs):
        model_inputs = self.tokenizer.prepare_translation_batch(src_texts=self.src_text, **tokenizer_kwargs).to(
            torch_device
        )
        self.assertEqual(self.model.device, model_inputs.input_ids.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
        )
        generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return generated_words
@@ -106,10 +112,10 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
    @slow
    def test_forward(self):
        src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
        expected_ids = [38, 121, 14, 697, 38848, 0]

        model_inputs: dict = self.tokenizer.prepare_translation_batch(src, tgt_texts=tgt).to(torch_device)
        self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())

        desired_keys = {
            "input_ids",
@@ -125,20 +131,19 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
    def test_tokenizer_equivalence(self):
        batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
        expected = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected, batch.input_ids[0].tolist())

    def test_unk_support(self):
        t = self.tokenizer
        ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
        expected = [t.unk_token_id, t.unk_token_id, t.eos_token_id]
        self.assertEqual(expected, ids)

    def test_pad_not_split(self):
        input_ids_w_pad = self.tokenizer.prepare_translation_batch(["I am a small frog <pad>"]).input_ids[0].tolist()
        expected_w_pad = [38, 121, 14, 697, 38848, self.tokenizer.pad_token_id, 0]  # pad
        self.assertListEqual(expected_w_pad, input_ids_w_pad)
    @slow
    def test_batch_generation_en_de(self):
@@ -187,9 +192,8 @@ class TestMarian_RU_FR(MarianIntegrationTest):
    src = "ru"
    tgt = "fr"
    src_text = ["Он показал мне рукопись своей новой пьесы."]
    expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]

    def test_batch_generation_ru_fr(self):
        self._assert_generated_batch_equal_expected()

@@ -197,36 +201,59 @@ class TestMarian_MT_EN(MarianIntegrationTest):
class TestMarian_MT_EN(MarianIntegrationTest):
    src = "mt"
    tgt = "en"
    src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
    expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]

    @slow
    def test_batch_generation_mt_en(self):
        self._assert_generated_batch_equal_expected()
class TestMarian_en_ROMANCE(MarianIntegrationTest):
    """Multilingual on target side."""

    src = "en"
    tgt = "ROMANCE"
    src_text = [
        ">>fr<< Don't spend so much time watching TV.",
        ">>pt<< Your message has been sent.",
        ">>es<< He's two years older than me.",
    ]
    expected_text = [
        "Ne passez pas autant de temps à regarder la télé.",
        "A sua mensagem foi enviada.",
        "Es dos años más viejo que yo.",
    ]

    @slow
    def test_batch_generation_en_ROMANCE_multi(self):
        self._assert_generated_batch_equal_expected()

    def test_tokenizer_handles_empty(self):
        normalized = self.tokenizer.normalize("")
        self.assertIsInstance(normalized, str)
        with self.assertRaises(ValueError):
            self.tokenizer.prepare_translation_batch([""])


@require_torch
class TestConversionUtils(unittest.TestCase):
    def test_renaming_multilingual(self):
        old_names = [
            "opus-mt-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
            "opus-mt-cmn+cn-fi",  # no group
            "opus-mt-en-de",  # standard name
            "opus-mt-en-de",  # standard name
        ]
        expected = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
        self.assertListEqual(expected, [convert_opus_name_to_hf_name(x) for x in old_names])

    def test_undoing_renaming(self):
        hf_names = ["opus-mt-ZH-fi", "opus-mt-cmn_cn-fi", "opus-mt-en-de", "opus-mt-en-de"]
        converted_opus_names = [convert_hf_name_to_opus_name(x) for x in hf_names]
        expected_opus_names = [
            "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
            "cmn+cn-fi",
            "en-de",  # standard name
            "en-de",
        ]
        self.assertListEqual(expected_opus_names, converted_opus_names)