Unverified Commit 9c2b2db2 authored by Sam Shleifer, committed by GitHub

[marian] Automate Tatoeba-Challenge conversion (#7709)

parent aacac8f7
......@@ -12,6 +12,7 @@ __pycache__/
tests/fixtures
logs/
lightning_logs/
lang_code_data/
# Distribution / packaging
.Python
......
import tempfile
import unittest
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
from transformers.file_utils import cached_property
from transformers.testing_utils import slow
class TatoebaConversionTester(unittest.TestCase):
@cached_property
def resolver(self):
tmp_dir = tempfile.mkdtemp()
return TatoebaConverter(save_dir=tmp_dir)
@slow
def test_resolver(self):
self.resolver.convert_models(["heb-eng"])
@slow
def test_model_card(self):
content, mmeta = self.resolver.write_model_card("opus-mt-he-en", dry_run=True)
assert mmeta["long_pair"] == "heb-eng"
Set up transformers following the instructions in README.md (I would fork first).
```bash
git clone git@github.com:huggingface/transformers.git
cd transformers
pip install -e .
pip install pandas
```
Get the required metadata (on init, `TatoebaConverter.download_metadata` also fetches these into `lang_code_data/` if they are missing):
```bash
curl https://cdn-datasets.huggingface.co/language_codes/language-codes-3b2.csv > language-codes-3b2.csv
curl https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv > iso-639-3.csv
```
Clone the Tatoeba-Challenge repo inside transformers
```bash
git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git
```
To convert a few models, call the conversion script from the command line:
```bash
python src/transformers/convert_marian_tatoeba_to_pytorch.py --models heb-eng eng-heb --save_dir converted
```
To convert lots of models, pass your list of Tatoeba model names to `resolver.convert_models` in a Python client or script.
```python
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
resolver = TatoebaConverter(save_dir='converted')
resolver.convert_models(['heb-eng', 'eng-heb'])
```
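`convert_models` also writes a model card for each converted model. To preview a card and its metadata without writing anything to disk, you can call `write_model_card` with `dry_run=True`, as the test above does. A minimal sketch, assuming the Tatoeba-Challenge repo has been cloned as above:
```python
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter

resolver = TatoebaConverter(save_dir='converted')
# dry_run=True returns (readme_content, metadata) without writing any files
content, mmeta = resolver.write_model_card('opus-mt-he-en', dry_run=True)
print(mmeta['long_pair'])  # heb-eng
```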
### Upload converted models
```bash
cd converted
transformers-cli login
for FILE in *; do transformers-cli upload $FILE; done
```
### Modifications
- To change the naming logic, change the code near `os.rename` (sketched below). The model card creation code may also need to change.
- To change the model card content, modify `TatoebaConverter.write_model_card`.
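For reference, the renaming step is an alpha-3 to alpha-2 lookup that falls back to the original code. A minimal sketch of that logic (the tiny dict here is illustrative only; the converter builds the full mapping from the two CSVs downloaded above):
```python
# Illustrative subset; TatoebaConverter builds the full table from
# iso-639-3.csv and language-codes-3b2.csv.
alpha3_to_alpha2 = {"heb": "he", "eng": "en"}

def get_two_letter_code(three_letter_code: str) -> str:
    # fall back to the alpha-3 code when no unique alpha-2 code exists
    return alpha3_to_alpha2.get(three_letter_code, three_letter_code)

src, tgt = "heb-eng".split("-")
print(f"opus-mt-{get_two_letter_code(src)}-{get_two_letter_code(tgt)}")  # opus-mt-he-en
```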
import argparse
import os
from pathlib import Path
from typing import List, Tuple
from transformers.convert_marian_to_pytorch import (
FRONT_MATTER_TEMPLATE,
_parse_readme,
convert_all_sentencepiece_models,
get_system_metadata,
remove_prefix,
remove_suffix,
)
try:
import pandas as pd
except ImportError:
pass
DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
ISO_PATH = "lang_code_data/iso-639-3.csv"
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
class TatoebaConverter:
"""Convert Tatoeba-Challenge models to huggingface format.
Steps:
(1) convert numpy state dict to hf format (same code as OPUS-MT-Train conversion).
(2) rename the opus model to huggingface format, i.e. replace each alpha-3 code with an alpha-2 code if a unique one exists.
e.g. aav-eng -> aav-en, heb-eng -> he-en
(3) write a model card containing the original Tatoeba-Challenge/README.md and extra info about alpha3 group members.
"""
def __init__(self, save_dir="marian_converted"):
assert Path(DEFAULT_REPO).exists(), "need git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git"
reg = self.make_tatoeba_registry()
self.download_metadata()
self.registry = reg
reg_df = pd.DataFrame(reg, columns=["id", "prepro", "url_model", "url_test_set"])
assert reg_df.id.value_counts().max() == 1
reg_df = reg_df.set_index("id")
reg_df["src"] = reg_df.reset_index().id.apply(lambda x: x.split("-")[0]).values
reg_df["tgt"] = reg_df.reset_index().id.apply(lambda x: x.split("-")[1]).values
released_cols = [
"url_base",
"pair", # (ISO639-3/ISO639-5 codes),
"short_pair", # (reduced codes),
"chrF2_score",
"bleu",
"brevity_penalty",
"ref_len",
"src_name",
"tgt_name",
]
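# released-models.txt is a headerless TSV; the trailing row is dropped (iloc[:-1]) since it is not a model entry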
released = pd.read_csv("Tatoeba-Challenge/models/released-models.txt", sep="\t", header=None).iloc[:-1]
released.columns = released_cols
released["fname"] = released["url_base"].apply(
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
)
released["2m"] = released.fname.str.startswith("2m")
released["date"] = pd.to_datetime(
released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-"))
)
released["base_ext"] = released.url_base.apply(lambda x: Path(x).name)
reg_df["base_ext"] = reg_df.url_model.apply(lambda x: Path(x).name)
metadata_new = reg_df.reset_index().merge(released.rename(columns={"pair": "id"}), on=["base_ext", "id"])
metadata_renamer = {"src": "src_alpha3", "tgt": "tgt_alpha3", "id": "long_pair", "date": "train_date"}
metadata_new = metadata_new.rename(columns=metadata_renamer)
metadata_new["src_alpha2"] = metadata_new.short_pair.apply(lambda x: x.split("-")[0])
metadata_new["tgt_alpha2"] = metadata_new.short_pair.apply(lambda x: x.split("-")[1])
DROP_COLS_BOTH = ["url_base", "base_ext", "fname"]
metadata_new = metadata_new.drop(DROP_COLS_BOTH, axis=1)
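# empty blacklist for now: no pairs are marked as preferring the older OPUS-MT-train model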
metadata_new["prefer_old"] = metadata_new.long_pair.isin([])
self.metadata = metadata_new
assert self.metadata.short_pair.value_counts().max() == 1, "Multiple metadata entries for a short pair"
self.metadata = self.metadata.set_index("short_pair")
# wget.download(LANG_CODE_URL)
mapper = pd.read_csv(LANG_CODE_PATH)
mapper.columns = ["a3", "a2", "ref"]
self.iso_table = pd.read_csv(ISO_PATH, sep="\t").rename(columns=lambda x: x.lower())
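# build the alpha-3 -> alpha-2 map from the ISO-639-3 table, then let language-codes-3b2.csv entries override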
more_3_to_2 = self.iso_table.set_index("id").part1.dropna().to_dict()
more_3_to_2.update(mapper.set_index("a3").a2.to_dict())
self.alpha3_to_alpha2 = more_3_to_2
self.model_card_dir = Path(save_dir)
self.constituents = GROUP_MEMBERS
def convert_models(self, tatoeba_ids, dry_run=False):
entries_to_convert = [x for x in self.registry if x[0] in tatoeba_ids]
converted_paths = convert_all_sentencepiece_models(entries_to_convert, dest_dir=self.model_card_dir)
for path in converted_paths:
long_pair = remove_prefix(path.name, "opus-mt-").split("-")  # e.g. heb-eng
assert len(long_pair) == 2
new_p_src = self.get_two_letter_code(long_pair[0])
new_p_tgt = self.get_two_letter_code(long_pair[1])
hf_model_id = f"opus-mt-{new_p_src}-{new_p_tgt}"
new_path = path.parent.joinpath(hf_model_id) # opus-mt-he-en
os.rename(str(path), str(new_path))
self.write_model_card(hf_model_id, dry_run=dry_run)
def get_two_letter_code(self, three_letter_code):
return self.alpha3_to_alpha2.get(three_letter_code, three_letter_code)
def expand_group_to_two_letter_codes(self, grp_name):
return [self.get_two_letter_code(x) for x in self.constituents[grp_name]]
def get_tags(self, code, ref_name):
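# returns ([language tags], is_multilingual)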
if len(code) == 2:
assert "languages" not in ref_name, f"{code}: {ref_name}"
return [code], False
elif "languages" in ref_name or len(self.constituents.get(code, [])) > 1:
group = self.expand_group_to_two_letter_codes(code)
group.append(code)
return group, True
else: # zho-> zh
print(f"Three letter monolingual code: {code}")
return [code], False
def resolve_lang_code(self, r) -> Tuple[List[str], bool, bool]:
"""R is a row in ported"""
short_pair = r.short_pair
src, tgt = short_pair.split("-")
src_tags, src_multilingual = self.get_tags(src, r.src_name)
assert isinstance(src_tags, list)
tgt_tags, tgt_multilingual = self.get_tags(tgt, r.tgt_name)
assert isinstance(tgt_tags, list)
return dedup(src_tags + tgt_tags), src_multilingual, tgt_multilingual
def write_model_card(
self,
hf_model_id: str,
repo_root=DEFAULT_REPO,
dry_run=False,
) -> Tuple[str, dict]:
"""Copy the most recent model's readme section from opus, and add metadata.
upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
"""
short_pair = remove_prefix(hf_model_id, "opus-mt-")
extra_metadata = self.metadata.loc[short_pair].drop("2m")
extra_metadata["short_pair"] = short_pair
lang_tags, src_multilingual, tgt_multilingual = self.resolve_lang_code(extra_metadata)
opus_name = f"{extra_metadata.src_alpha3}-{extra_metadata.tgt_alpha3}"
# opus_name: str = self.convert_hf_name_to_opus_name(hf_model_name)
assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"
s, t = ",".join(opus_src), ",".join(opus_tgt)
metadata = {
"hf_name": short_pair,
"source_languages": s,
"target_languages": t,
"opus_readme_url": readme_url,
"original_repo": repo_root,
"tags": ["translation"],
"languages": lang_tags,
}
lang_tags = l2front_matter(lang_tags)
metadata["src_constituents"] = self.constituents[s]
metadata["tgt_constituents"] = self.constituents[t]
metadata["src_multilingual"] = src_multilingual
metadata["tgt_multilingual"] = tgt_multilingual
metadata.update(extra_metadata)
metadata.update(get_system_metadata(repo_root))
# combine with Tatoeba markdown
extra_markdown = f"### {short_pair}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
content = opus_readme_path.open().read()
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
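# keep everything from the second "*" bullet onward; "\n* " is re-added below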
splat = content.split("*")[2:]
content = "*".join(splat)
# BETTER FRONT MATTER LOGIC
content = (
FRONT_MATTER_TEMPLATE.format(lang_tags)
+ extra_markdown
+ "\n* "
+ content.replace("download", "download original " "weights")
)
items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
sec3 = "\n### System Info: \n" + items
content += sec3
if dry_run:
return content, metadata
sub_dir = self.model_card_dir / hf_model_id
sub_dir.mkdir(exist_ok=True)
dest = sub_dir / "README.md"
dest.open("w").write(content)
pd.Series(metadata).to_json(sub_dir / "metadata.json")
return content, metadata
def download_metadata(self):
Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)
import wget
if not os.path.exists(ISO_PATH):
wget.download(ISO_URL, ISO_PATH)
if not os.path.exists(LANG_CODE_PATH):
wget.download(LANG_CODE_URL, LANG_CODE_PATH)
@staticmethod
def make_tatoeba_registry(repo_path=DEFAULT_MODEL_DIR):
if not (Path(repo_path) / "zho-eng" / "README.md").exists():
raise ValueError(
f"repo_path:{repo_path} does not exist: "
"You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
)
results = {}
for p in Path(repo_path).iterdir():
if len(p.name) != 7:
continue
lns = list(open(p / "README.md").readlines())
results[p.name] = _parse_readme(lns)
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
GROUP_MEMBERS = {
# three letter code -> (group/language name, {constituents...})
# if this language is on the target side the constituents can be used as target language codes.
# if the language is on the source side they are supported natively without special codes.
"aav": ("Austro-Asiatic languages", {"hoc", "hoc_Latn", "kha", "khm", "khm_Latn", "mnw", "vie", "vie_Hani"}),
"afa": (
"Afro-Asiatic languages",
{
"acm",
"afb",
"amh",
"apc",
"ara",
"arq",
"ary",
"arz",
"hau_Latn",
"heb",
"kab",
"mlt",
"rif_Latn",
"shy_Latn",
"som",
"thv",
"tir",
},
),
"afr": ("Afrikaans", {"afr"}),
"alv": (
"Atlantic-Congo languages",
{
"ewe",
"fuc",
"fuv",
"ibo",
"kin",
"lin",
"lug",
"nya",
"run",
"sag",
"sna",
"swh",
"toi_Latn",
"tso",
"umb",
"wol",
"xho",
"yor",
"zul",
},
),
"ara": ("Arabic", {"afb", "apc", "apc_Latn", "ara", "ara_Latn", "arq", "arq_Latn", "arz"}),
"art": (
"Artificial languages",
{
"afh_Latn",
"avk_Latn",
"dws_Latn",
"epo",
"ido",
"ido_Latn",
"ile_Latn",
"ina_Latn",
"jbo",
"jbo_Cyrl",
"jbo_Latn",
"ldn_Latn",
"lfn_Cyrl",
"lfn_Latn",
"nov_Latn",
"qya",
"qya_Latn",
"sjn_Latn",
"tlh_Latn",
"tzl",
"tzl_Latn",
"vol_Latn",
},
),
"aze": ("Azerbaijani", {"aze_Latn"}),
"bat": ("Baltic languages", {"lit", "lav", "prg_Latn", "ltg", "sgs"}),
"bel": ("Belarusian", {"bel", "bel_Latn"}),
"ben": ("Bengali", {"ben"}),
"bnt": (
"Bantu languages",
{"kin", "lin", "lug", "nya", "run", "sna", "swh", "toi_Latn", "tso", "umb", "xho", "zul"},
),
"bul": ("Bulgarian", {"bul", "bul_Latn"}),
"cat": ("Catalan", {"cat"}),
"cau": ("Caucasian languages", {"abk", "kat", "che", "ady"}),
"ccs": ("South Caucasian languages", {"kat"}),
"ceb": ("Cebuano", {"ceb"}),
"cel": ("Celtic languages", {"gla", "gle", "bre", "cor", "glv", "cym"}),
"ces": ("Czech", {"ces"}),
"cpf": ("Creoles and pidgins, French‑based", {"gcf_Latn", "hat", "mfe"}),
"cpp": (
"Creoles and pidgins, Portuguese-based",
{"zsm_Latn", "ind", "pap", "min", "tmw_Latn", "max_Latn", "zlm_Latn"},
),
"cus": ("Cushitic languages", {"som"}),
"dan": ("Danish", {"dan"}),
"deu": ("German", {"deu"}),
"dra": ("Dravidian languages", {"tam", "kan", "mal", "tel"}),
"ell": ("Modern Greek (1453-)", {"ell"}),
"eng": ("English", {"eng"}),
"epo": ("Esperanto", {"epo"}),
"est": ("Estonian", {"est"}),
"euq": ("Basque (family)", {"eus"}),
"eus": ("Basque", {"eus"}),
"fin": ("Finnish", {"fin"}),
"fiu": (
"Finno-Ugrian languages",
{
"est",
"fin",
"fkv_Latn",
"hun",
"izh",
"kpv",
"krl",
"liv_Latn",
"mdf",
"mhr",
"myv",
"sma",
"sme",
"udm",
"vep",
"vro",
},
),
"fra": ("French", {"fra"}),
"gem": (
"Germanic languages",
{
"afr",
"ang_Latn",
"dan",
"deu",
"eng",
"enm_Latn",
"fao",
"frr",
"fry",
"gos",
"got_Goth",
"gsw",
"isl",
"ksh",
"ltz",
"nds",
"nld",
"nno",
"nob",
"nob_Hebr",
"non_Latn",
"pdc",
"sco",
"stq",
"swe",
"swg",
"yid",
},
),
"gle": ("Irish", {"gle"}),
"glg": ("Galician", {"glg"}),
"gmq": ("North Germanic languages", {"dan", "nob", "nob_Hebr", "swe", "isl", "nno", "non_Latn", "fao"}),
"gmw": (
"West Germanic languages",
{
"afr",
"ang_Latn",
"deu",
"eng",
"enm_Latn",
"frr",
"fry",
"gos",
"gsw",
"ksh",
"ltz",
"nds",
"nld",
"pdc",
"sco",
"stq",
"swg",
"yid",
},
),
"grk": ("Greek languages", {"grc_Grek", "ell"}),
"hbs": ("Serbo-Croatian", {"hrv", "srp_Cyrl", "bos_Latn", "srp_Latn"}),
"heb": ("Hebrew", {"heb"}),
"hin": ("Hindi", {"hin"}),
"hun": ("Hungarian", {"hun"}),
"hye": ("Armenian", {"hye", "hye_Latn"}),
"iir": (
"Indo-Iranian languages",
{
"asm",
"awa",
"ben",
"bho",
"gom",
"guj",
"hif_Latn",
"hin",
"jdt_Cyrl",
"kur_Arab",
"kur_Latn",
"mai",
"mar",
"npi",
"ori",
"oss",
"pan_Guru",
"pes",
"pes_Latn",
"pes_Thaa",
"pnb",
"pus",
"rom",
"san_Deva",
"sin",
"snd_Arab",
"tgk_Cyrl",
"tly_Latn",
"urd",
"zza",
},
),
"ilo": ("Iloko", {"ilo"}),
"inc": (
"Indic languages",
{
"asm",
"awa",
"ben",
"bho",
"gom",
"guj",
"hif_Latn",
"hin",
"mai",
"mar",
"npi",
"ori",
"pan_Guru",
"pnb",
"rom",
"san_Deva",
"sin",
"snd_Arab",
"urd",
},
),
"ine": (
"Indo-European languages",
{
"afr",
"afr_Arab",
"aln",
"ang_Latn",
"arg",
"asm",
"ast",
"awa",
"bel",
"bel_Latn",
"ben",
"bho",
"bjn",
"bos_Latn",
"bre",
"bul",
"bul_Latn",
"cat",
"ces",
"cor",
"cos",
"csb_Latn",
"cym",
"dan",
"deu",
"dsb",
"egl",
"ell",
"eng",
"enm_Latn",
"ext",
"fao",
"fra",
"frm_Latn",
"frr",
"fry",
"gcf_Latn",
"gla",
"gle",
"glg",
"glv",
"gom",
"gos",
"got_Goth",
"grc_Grek",
"gsw",
"guj",
"hat",
"hif_Latn",
"hin",
"hrv",
"hsb",
"hye",
"hye_Latn",
"ind",
"isl",
"ita",
"jdt_Cyrl",
"ksh",
"kur_Arab",
"kur_Latn",
"lad",
"lad_Latn",
"lat_Grek",
"lat_Latn",
"lav",
"lij",
"lit",
"lld_Latn",
"lmo",
"ltg",
"ltz",
"mai",
"mar",
"max_Latn",
"mfe",
"min",
"mkd",
"mwl",
"nds",
"nld",
"nno",
"nob",
"nob_Hebr",
"non_Latn",
"npi",
"oci",
"ori",
"orv_Cyrl",
"oss",
"pan_Guru",
"pap",
"pcd",
"pdc",
"pes",
"pes_Latn",
"pes_Thaa",
"pms",
"pnb",
"pol",
"por",
"prg_Latn",
"pus",
"roh",
"rom",
"ron",
"rue",
"rus",
"rus_Latn",
"san_Deva",
"scn",
"sco",
"sgs",
"sin",
"slv",
"snd_Arab",
"spa",
"sqi",
"srd",
"srp_Cyrl",
"srp_Latn",
"stq",
"swe",
"swg",
"tgk_Cyrl",
"tly_Latn",
"tmw_Latn",
"ukr",
"urd",
"vec",
"wln",
"yid",
"zlm_Latn",
"zsm_Latn",
"zza",
},
),
"isl": ("Icelandic", {"isl"}),
"ita": ("Italian", {"ita"}),
"itc": (
"Italic languages",
{
"arg",
"ast",
"bjn",
"cat",
"cos",
"egl",
"ext",
"fra",
"frm_Latn",
"gcf_Latn",
"glg",
"hat",
"ind",
"ita",
"lad",
"lad_Latn",
"lat_Grek",
"lat_Latn",
"lij",
"lld_Latn",
"lmo",
"max_Latn",
"mfe",
"min",
"mwl",
"oci",
"pap",
"pcd",
"pms",
"por",
"roh",
"ron",
"scn",
"spa",
"srd",
"tmw_Latn",
"vec",
"wln",
"zlm_Latn",
"zsm_Latn",
},
),
"jpn": ("Japanese", {"jpn", "jpn_Bopo", "jpn_Hang", "jpn_Hani", "jpn_Hira", "jpn_Kana", "jpn_Latn", "jpn_Yiii"}),
"jpx": ("Japanese (family)", {"jpn"}),
"kat": ("Georgian", {"kat"}),
"kor": ("Korean", {"kor_Hani", "kor_Hang", "kor_Latn", "kor"}),
"lav": ("Latvian", {"lav"}),
"lit": ("Lithuanian", {"lit"}),
"mkd": ("Macedonian", {"mkd"}),
"mkh": ("Mon-Khmer languages", {"vie_Hani", "mnw", "vie", "kha", "khm_Latn", "khm"}),
"msa": ("Malay (macrolanguage)", {"zsm_Latn", "ind", "max_Latn", "zlm_Latn", "min"}),
"mul": (
"Multiple languages",
{
"abk",
"acm",
"ady",
"afb",
"afh_Latn",
"afr",
"akl_Latn",
"aln",
"amh",
"ang_Latn",
"apc",
"ara",
"arg",
"arq",
"ary",
"arz",
"asm",
"ast",
"avk_Latn",
"awa",
"aze_Latn",
"bak",
"bam_Latn",
"bel",
"bel_Latn",
"ben",
"bho",
"bod",
"bos_Latn",
"bre",
"brx",
"brx_Latn",
"bul",
"bul_Latn",
"cat",
"ceb",
"ces",
"cha",
"che",
"chr",
"chv",
"cjy_Hans",
"cjy_Hant",
"cmn",
"cmn_Hans",
"cmn_Hant",
"cor",
"cos",
"crh",
"crh_Latn",
"csb_Latn",
"cym",
"dan",
"deu",
"dsb",
"dtp",
"dws_Latn",
"egl",
"ell",
"enm_Latn",
"epo",
"est",
"eus",
"ewe",
"ext",
"fao",
"fij",
"fin",
"fkv_Latn",
"fra",
"frm_Latn",
"frr",
"fry",
"fuc",
"fuv",
"gan",
"gcf_Latn",
"gil",
"gla",
"gle",
"glg",
"glv",
"gom",
"gos",
"got_Goth",
"grc_Grek",
"grn",
"gsw",
"guj",
"hat",
"hau_Latn",
"haw",
"heb",
"hif_Latn",
"hil",
"hin",
"hnj_Latn",
"hoc",
"hoc_Latn",
"hrv",
"hsb",
"hun",
"hye",
"iba",
"ibo",
"ido",
"ido_Latn",
"ike_Latn",
"ile_Latn",
"ilo",
"ina_Latn",
"ind",
"isl",
"ita",
"izh",
"jav",
"jav_Java",
"jbo",
"jbo_Cyrl",
"jbo_Latn",
"jdt_Cyrl",
"jpn",
"kab",
"kal",
"kan",
"kat",
"kaz_Cyrl",
"kaz_Latn",
"kek_Latn",
"kha",
"khm",
"khm_Latn",
"kin",
"kir_Cyrl",
"kjh",
"kpv",
"krl",
"ksh",
"kum",
"kur_Arab",
"kur_Latn",
"lad",
"lad_Latn",
"lao",
"lat_Latn",
"lav",
"ldn_Latn",
"lfn_Cyrl",
"lfn_Latn",
"lij",
"lin",
"lit",
"liv_Latn",
"lkt",
"lld_Latn",
"lmo",
"ltg",
"ltz",
"lug",
"lzh",
"lzh_Hans",
"mad",
"mah",
"mai",
"mal",
"mar",
"max_Latn",
"mdf",
"mfe",
"mhr",
"mic",
"min",
"mkd",
"mlg",
"mlt",
"mnw",
"moh",
"mon",
"mri",
"mwl",
"mww",
"mya",
"myv",
"nan",
"nau",
"nav",
"nds",
"niu",
"nld",
"nno",
"nob",
"nob_Hebr",
"nog",
"non_Latn",
"nov_Latn",
"npi",
"nya",
"oci",
"ori",
"orv_Cyrl",
"oss",
"ota_Arab",
"ota_Latn",
"pag",
"pan_Guru",
"pap",
"pau",
"pdc",
"pes",
"pes_Latn",
"pes_Thaa",
"pms",
"pnb",
"pol",
"por",
"ppl_Latn",
"prg_Latn",
"pus",
"quc",
"qya",
"qya_Latn",
"rap",
"rif_Latn",
"roh",
"rom",
"ron",
"rue",
"run",
"rus",
"sag",
"sah",
"san_Deva",
"scn",
"sco",
"sgs",
"shs_Latn",
"shy_Latn",
"sin",
"sjn_Latn",
"slv",
"sma",
"sme",
"smo",
"sna",
"snd_Arab",
"som",
"spa",
"sqi",
"srp_Cyrl",
"srp_Latn",
"stq",
"sun",
"swe",
"swg",
"swh",
"tah",
"tam",
"tat",
"tat_Arab",
"tat_Latn",
"tel",
"tet",
"tgk_Cyrl",
"tha",
"tir",
"tlh_Latn",
"tly_Latn",
"tmw_Latn",
"toi_Latn",
"ton",
"tpw_Latn",
"tso",
"tuk",
"tuk_Latn",
"tur",
"tvl",
"tyv",
"tzl",
"tzl_Latn",
"udm",
"uig_Arab",
"uig_Cyrl",
"ukr",
"umb",
"urd",
"uzb_Cyrl",
"uzb_Latn",
"vec",
"vie",
"vie_Hani",
"vol_Latn",
"vro",
"war",
"wln",
"wol",
"wuu",
"xal",
"xho",
"yid",
"yor",
"yue",
"yue_Hans",
"yue_Hant",
"zho",
"zho_Hans",
"zho_Hant",
"zlm_Latn",
"zsm_Latn",
"zul",
"zza",
},
),
"nic": (
"Niger-Kordofanian languages",
{
"bam_Latn",
"ewe",
"fuc",
"fuv",
"ibo",
"kin",
"lin",
"lug",
"nya",
"run",
"sag",
"sna",
"swh",
"toi_Latn",
"tso",
"umb",
"wol",
"xho",
"yor",
"zul",
},
),
"nld": ("Dutch", {"nld"}),
"nor": ("Norwegian", {"nob", "nno"}),
"phi": ("Philippine languages", {"ilo", "akl_Latn", "war", "hil", "pag", "ceb"}),
"pol": ("Polish", {"pol"}),
"por": ("Portuguese", {"por"}),
"pqe": (
"Eastern Malayo-Polynesian languages",
{"fij", "gil", "haw", "mah", "mri", "nau", "niu", "rap", "smo", "tah", "ton", "tvl"},
),
"roa": (
"Romance languages",
{
"arg",
"ast",
"cat",
"cos",
"egl",
"ext",
"fra",
"frm_Latn",
"gcf_Latn",
"glg",
"hat",
"ind",
"ita",
"lad",
"lad_Latn",
"lij",
"lld_Latn",
"lmo",
"max_Latn",
"mfe",
"min",
"mwl",
"oci",
"pap",
"pms",
"por",
"roh",
"ron",
"scn",
"spa",
"tmw_Latn",
"vec",
"wln",
"zlm_Latn",
"zsm_Latn",
},
),
"ron": ("Romanian", {"ron"}),
"run": ("Rundi", {"run"}),
"rus": ("Russian", {"rus"}),
"sal": ("Salishan languages", {"shs_Latn"}),
"sem": ("Semitic languages", {"acm", "afb", "amh", "apc", "ara", "arq", "ary", "arz", "heb", "mlt", "tir"}),
"sla": (
"Slavic languages",
{
"bel",
"bel_Latn",
"bos_Latn",
"bul",
"bul_Latn",
"ces",
"csb_Latn",
"dsb",
"hrv",
"hsb",
"mkd",
"orv_Cyrl",
"pol",
"rue",
"rus",
"slv",
"srp_Cyrl",
"srp_Latn",
"ukr",
},
),
"slv": ("Slovenian", {"slv"}),
"spa": ("Spanish", {"spa"}),
"swe": ("Swedish", {"swe"}),
"taw": ("Tai", {"lao", "tha"}),
"tgl": ("Tagalog", {"tgl_Latn"}),
"tha": ("Thai", {"tha"}),
"trk": (
"Turkic languages",
{
"aze_Latn",
"bak",
"chv",
"crh",
"crh_Latn",
"kaz_Cyrl",
"kaz_Latn",
"kir_Cyrl",
"kjh",
"kum",
"ota_Arab",
"ota_Latn",
"sah",
"tat",
"tat_Arab",
"tat_Latn",
"tuk",
"tuk_Latn",
"tur",
"tyv",
"uig_Arab",
"uig_Cyrl",
"uzb_Cyrl",
"uzb_Latn",
},
),
"tur": ("Turkish", {"tur"}),
"ukr": ("Ukrainian", {"ukr"}),
"urd": ("Urdu", {"urd"}),
"urj": (
"Uralic languages",
{
"est",
"fin",
"fkv_Latn",
"hun",
"izh",
"kpv",
"krl",
"liv_Latn",
"mdf",
"mhr",
"myv",
"sma",
"sme",
"udm",
"vep",
"vro",
},
),
"vie": ("Vietnamese", {"vie", "vie_Hani"}),
"war": ("Waray (Philippines)", {"war"}),
"zho": (
"Chinese",
{
"cjy_Hans",
"cjy_Hant",
"cmn",
"cmn_Bopo",
"cmn_Hang",
"cmn_Hani",
"cmn_Hans",
"cmn_Hant",
"cmn_Hira",
"cmn_Kana",
"cmn_Latn",
"cmn_Yiii",
"gan",
"hak_Hani",
"lzh",
"lzh_Bopo",
"lzh_Hang",
"lzh_Hani",
"lzh_Hans",
"lzh_Hira",
"lzh_Kana",
"lzh_Yiii",
"nan",
"nan_Hani",
"wuu",
"wuu_Bopo",
"wuu_Hani",
"wuu_Latn",
"yue",
"yue_Bopo",
"yue_Hang",
"yue_Hani",
"yue_Hans",
"yue_Hant",
"yue_Hira",
"yue_Kana",
"zho",
"zho_Hans",
"zho_Hant",
},
),
"zle": ("East Slavic languages", {"bel", "orv_Cyrl", "bel_Latn", "rus", "ukr", "rue"}),
"zls": ("South Slavic languages", {"bos_Latn", "bul", "bul_Latn", "hrv", "mkd", "slv", "srp_Cyrl", "srp_Latn"}),
"zlw": ("West Slavic languages", {"csb_Latn", "dsb", "hsb", "pol", "ces"}),
}
def l2front_matter(langs):
return "".join(f"- {l}\n" for l in langs)
def dedup(lst):
"""Preservers order"""
new_lst = []
for item in lst:
if not item:
continue
elif item in new_lst:
continue
else:
new_lst.append(item)
return new_lst
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--models", action="append", help="<Required> Set flag", required=True, nargs="+", dest="models"
)
parser.add_argument("-save_dir", "--save_dir", default="marian_converted", help="where to save converted models")
args = parser.parse_args()
resolver = TatoebaConverter(save_dir=args.save_dir)
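# action="append" with nargs="+" yields a list of lists, hence args.models[0]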
resolver.convert_models(args.models[0])
import argparse
import json
import os
import shutil
import socket
import time
import warnings
from pathlib import Path
from typing import Dict, List, Union
from zipfile import ZipFile
import numpy as np
......@@ -23,85 +22,6 @@ def remove_suffix(text: str, suffix: str):
return text # or whatever
def _process_benchmark_table_row(x):
fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
assert len(fields) == 3
return (fields[0], float(fields[1]), float(fields[2]))
def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
md_content = Path(readme_path).open().read()
entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
data = lmap(_process_benchmark_table_row, entries)
return data
def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
import pandas as pd
newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
short_to_new_bleu = newest_released.set_index("short_pair").bleu
assert released.groupby("short_pair").pair.nunique().max() == 1
short_to_long = released.groupby("short_pair").pair.first().to_dict()
overlap_short = old_reg.index.intersection(released.short_pair.unique())
overlap_long = [short_to_long[o] for o in overlap_short]
new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
def get_old_bleu(o) -> float:
pat = old_repo_path + "/{}/README.md"
bm_data = process_last_benchmark_table(pat.format(o))
tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
if tato_bleu.shape[0] > 0:
return tato_bleu.iloc[0]
else:
return np.nan
old_bleu = [get_old_bleu(o) for o in overlap_short]
cmp_df = pd.DataFrame(
dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
).fillna(-1)
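# "dominated": pairs where the already-ported OPUS-MT-train model scores higher BLEU than the new Tatoeba model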
dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
blacklist = dominated.long.unique().tolist() # 3 letter codes
return whitelist_df, dominated, blacklist
def get_released_df(new_repo_path, old_repo_path):
import pandas as pd
released_cols = [
"url_base",
"pair", # (ISO639-3/ISO639-5 codes),
"short_pair", # (reduced codes),
"chrF2_score",
"bleu",
"brevity_penalty",
"ref_len",
"src_name",
"tgt_name",
]
released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
released.columns = released_cols
old_reg = make_registry(repo_path=old_repo_path)
old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
assert old_reg.id.value_counts().max() == 1
old_reg = old_reg.set_index("id")
released["fname"] = released["url_base"].apply(
lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
)
released["2m"] = released.fname.str.startswith("2m")
released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
newest_released = released.sort_values("date", ascending=False).drop_duplicates(["short_pair"], keep="first")
return newest_released, old_reg, released
def remove_prefix(text: str, prefix: str):
if text.startswith(prefix):
return text[len(prefix) :]
......@@ -183,7 +103,11 @@ def find_model_file(dest_dir): # this one better
# Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
ROM_GROUP = (
"fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT"
"+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co"
"+nap+scn+vec+sc+ro+la"
)
GROUPS = [
("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
(ROM_GROUP, "ROMANCE"),
......@@ -221,13 +145,15 @@ ORG_NAME = "Helsinki-NLP/"
def convert_opus_name_to_hf_name(x):
"""For OPUS-MT-Train/ DEPRECATED"""
for substr, grp_name in GROUPS:
x = x.replace(substr, grp_name)
return x.replace("+", "_")
def convert_hf_name_to_opus_name(hf_model_name):
"""Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
"""Relies on the assumption that there are no language codes like pt_br in models that are not in
GROUP_TO_OPUS_NAME."""
hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
if hf_model_name in GROUP_TO_OPUS_NAME:
opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
......@@ -247,8 +173,9 @@ def get_system_metadata(repo_root):
)
front_matter = """---
language: {}
FRONT_MATTER_TEMPLATE = """---
language:
{}
tags:
- translation
......@@ -256,11 +183,13 @@ license: apache-2.0
---
"""
DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
def write_model_card(
hf_model_name: str,
repo_root="OPUS-MT-train",
repo_root=DEFAULT_REPO,
save_dir=Path("marian_converted"),
dry_run=False,
extra_metadata={},
......@@ -294,7 +223,10 @@ def write_model_card(
# combine with opus markdown
extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: {metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
extra_markdown = (
f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: "
f"{metadata['tgt_name']} \n* OPUS readme: [{opus_name}]({readme_url})\n"
)
content = opus_readme_path.open().read()
content = content.split("\n# ")[-1] # Get the lowest level 1 header in the README -- the most recent model.
......@@ -302,7 +234,7 @@ def write_model_card(
print(splat[3])
content = "*".join(splat)
content = (
FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
+ extra_markdown
+ "\n* "
+ content.replace("download", "download original weights")
......@@ -323,48 +255,6 @@ def write_model_card(
return content, metadata
def get_clean_model_id_mapping(multiling_model_ids):
return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
def expand_group_to_two_letter_codes(grp_name):
raise NotImplementedError()
def get_two_letter_code(three_letter_code):
raise NotImplementedError()
# return two_letter_code
def get_tags(code, ref_name):
if len(code) == 2:
assert "languages" not in ref_name, f"{code}: {ref_name}"
return [code], False
elif "languages" in ref_name:
group = expand_group_to_two_letter_codes(code)
group.append(code)
return group, True
else: # zho-> zh
raise ValueError(f"Three letter monolingual code: {code}")
def resolve_lang_code(r):
"""R is a row in ported"""
short_pair = r.short_pair
src, tgt = short_pair.split("-")
src_tags, src_multilingual = get_tags(src, r.src_name)
assert isinstance(src_tags, list)
tgt_tags, tgt_multilingual = get_tags(tgt, r.tgt_name)
assert isinstance(tgt_tags, list)
if src_multilingual:
src_tags.append("multilingual_src")
if tgt_multilingual:
tgt_tags.append("multilingual_tgt")
return src_tags + tgt_tags
# process target
def make_registry(repo_path="Opus-MT-train/models"):
if not (Path(repo_path) / "fr-en" / "README.md").exists():
raise ValueError(
......@@ -382,36 +272,25 @@ def make_registry(repo_path="Opus-MT-train/models"):
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
if not (Path(repo_path) / "zho-eng" / "README.md").exists():
raise ValueError(
f"repo_path:{repo_path} does not exist: "
"You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
)
results = {}
for p in Path(repo_path).iterdir():
if len(p.name) != 7:
continue
lns = list(open(p / "README.md").readlines())
results[p.name] = _parse_readme(lns)
return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
"""Requires 300GB"""
save_dir = Path("marian_ckpt")
dest_dir = Path("marian_converted")
dest_dir = Path(dest_dir)
dest_dir.mkdir(exist_ok=True)
save_paths = []
if model_list is None:
model_list: list = make_registry(repo_path=repo_path)
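# for each SentencePiece model: download and unzip if needed, convert, and record the output path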
for k, prepro, download, test_set_url in tqdm(model_list):
if "SentencePiece" not in prepro: # dont convert BPE models.
continue
if not os.path.exists(save_dir / k / "pytorch_model.bin"):
if not os.path.exists(save_dir / k):
download_and_unzip(download, save_dir / k)
pair_name = convert_opus_name_to_hf_name(k)
convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")
save_paths.append(dest_dir / f"opus-mt-{pair_name}")
return save_paths
def lmap(f, x) -> List:
return list(map(f, x))
......@@ -493,15 +372,6 @@ def add_special_tokens_to_vocab(model_dir: Path) -> None:
save_tokenizer_config(model_dir)
def save_tokenizer(self, save_directory):
dest = Path(save_directory)
src_path = Path(self.init_kwargs["source_spm"])
for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
save_json(self.encoder, dest / "vocab.json")
def check_equal(marian_cfg, k1, k2):
v1, v2 = marian_cfg[k1], marian_cfg[k2]
assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"
......@@ -698,14 +568,14 @@ def convert(source_dir: Path, dest_dir):
add_special_tokens_to_vocab(source_dir)
tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
tokenizer.save_pretrained(dest_dir)
opus_state = OpusState(source_dir)
assert opus_state.cfg["vocab_size"] == len(
tokenizer.encoder
), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
# save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
# ^^ Uncomment to save human readable marian config for debugging
model = opus_state.load_marian_model()
model = model.half()
......@@ -732,15 +602,11 @@ def unzip(zip_path: str, dest_dir: str) -> None:
if __name__ == "__main__":
"""
Tatoeba conversion instructions in scripts/tatoeba/README.md
"""
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
args = parser.parse_args()
......