chenpangpang/transformers — unverified commit 9c2b2db2
Authored Oct 12, 2020 by Sam Shleifer; committed by GitHub on Oct 12, 2020

[marian] Automate Tatoeba-Challenge conversion (#7709)

Parent: aacac8f7
Showing 5 changed files with 1347 additions and 165 deletions (+1347 / −165).
- .gitignore (+1 / −0)
- examples/seq2seq/test_tatoeba_conversion.py (+22 / −0)
- scripts/tatoeba/README.md (+44 / −0)
- src/transformers/convert_marian_tatoeba_to_pytorch.py (+1249 / −0)
- src/transformers/convert_marian_to_pytorch.py (+31 / −165)
.gitignore

```diff
@@ -12,6 +12,7 @@ __pycache__/
 tests/fixtures
 logs/
 lightning_logs/
+lang_code_data/

 # Distribution / packaging
 .Python
```
examples/seq2seq/test_tatoeba_conversion.py (new file, mode 100644)

```python
import tempfile
import unittest

from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter
from transformers.file_utils import cached_property
from transformers.testing_utils import slow


class TatoebaConversionTester(unittest.TestCase):
    @cached_property
    def resolver(self):
        tmp_dir = tempfile.mkdtemp()
        return TatoebaConverter(save_dir=tmp_dir)

    @slow
    def test_resolver(self):
        self.resolver.convert_models(["heb-eng"])

    @slow
    def test_model_card(self):
        content, mmeta = self.resolver.write_model_card("opus-mt-he-en", dry_run=True)
        assert mmeta["long_pair"] == "heb-eng"
```
scripts/tatoeba/README.md (new file, mode 100644)
Set up transformers following the instructions in README.md (I would fork first):

```bash
git clone git@github.com:huggingface/transformers.git
cd transformers
pip install -e .
pip install pandas
```

Get the required metadata:

```bash
curl https://cdn-datasets.huggingface.co/language_codes/language-codes-3b2.csv > language-codes-3b2.csv
curl https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv > iso-639-3.csv
```

Clone the Tatoeba-Challenge repo inside transformers:

```bash
git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git
```

To convert a few models, call the conversion script from the command line:

```bash
python src/transformers/convert_marian_tatoeba_to_pytorch.py --models heb-eng eng-heb --save_dir converted
```

To convert many models, pass your list of Tatoeba model names to `resolver.convert_models` in a Python client or script:

```python
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter

resolver = TatoebaConverter(save_dir='converted')
resolver.convert_models(['heb-eng', 'eng-heb'])
```

### Upload converted models

```bash
cd converted
transformers-cli login
for FILE in *; do
  transformers-cli upload $FILE;
done
```

### Modifications

- To change the naming logic, change the code near `os.rename`. The model card creation code may also need to change.
- To change the model card content, modify `TatoebaConverter.write_model_card` (a dry-run preview sketch follows below).
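For quick iteration on card content without converting any weights, `write_model_card` accepts `dry_run=True` and returns the rendered card text together with its metadata (the same call exercised by `examples/seq2seq/test_tatoeba_conversion.py`). A minimal sketch, assuming the Tatoeba-Challenge clone and the language-code CSVs from the steps above are already in place:

```python
from transformers.convert_marian_tatoeba_to_pytorch import TatoebaConverter

# Constructing the converter parses Tatoeba-Challenge/models/released-models.txt and
# downloads the ISO-639 code tables into lang_code_data/ if they are missing.
resolver = TatoebaConverter(save_dir='converted')

# dry_run=True renders the card in memory instead of writing README.md / metadata.json.
content, mmeta = resolver.write_model_card('opus-mt-he-en', dry_run=True)
print(mmeta['long_pair'])  # 'heb-eng'
print(content[:300])       # front matter followed by the most recent section of the OPUS readme
```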
src/transformers/convert_marian_tatoeba_to_pytorch.py (new file, mode 100644)
```python
import argparse
import os
from pathlib import Path
from typing import List, Tuple

from transformers.convert_marian_to_pytorch import (
    FRONT_MATTER_TEMPLATE,
    _parse_readme,
    convert_all_sentencepiece_models,
    get_system_metadata,
    remove_prefix,
    remove_suffix,
)


try:
    import pandas as pd
except ImportError:
    pass

DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
ISO_PATH = "lang_code_data/iso-639-3.csv"
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"


class TatoebaConverter:
    """Convert Tatoeba-Challenge models to huggingface format.

    Steps:
        (1) convert numpy state dict to hf format (same code as OPUS-MT-Train conversion).
        (2) rename opus model to huggingface format. This means replace each alpha3 code with an alpha2 code if a
            unique one exists. e.g. aav-eng -> aav-en, heb-eng -> he-en
        (3) write a model card containing the original Tatoeba-Challenge/README.md and extra info about alpha3 group
            members.
    """

    def __init__(self, save_dir="marian_converted"):
        assert Path(DEFAULT_REPO).exists(), "need git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git"
        reg = self.make_tatoeba_registry()
        self.download_metadata()
        self.registry = reg
        reg_df = pd.DataFrame(reg, columns=["id", "prepro", "url_model", "url_test_set"])
        assert reg_df.id.value_counts().max() == 1
        reg_df = reg_df.set_index("id")
        reg_df["src"] = reg_df.reset_index().id.apply(lambda x: x.split("-")[0]).values
        reg_df["tgt"] = reg_df.reset_index().id.apply(lambda x: x.split("-")[1]).values

        released_cols = [
            "url_base",
            "pair",  # (ISO639-3/ISO639-5 codes),
            "short_pair",  # (reduced codes),
            "chrF2_score",
            "bleu",
            "brevity_penalty",
            "ref_len",
            "src_name",
            "tgt_name",
        ]

        released = pd.read_csv("Tatoeba-Challenge/models/released-models.txt", sep="\t", header=None).iloc[:-1]
        released.columns = released_cols
        released["fname"] = released["url_base"].apply(
            lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
        )

        released["2m"] = released.fname.str.startswith("2m")
        released["date"] = pd.to_datetime(
            released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-"))
        )

        released["base_ext"] = released.url_base.apply(lambda x: Path(x).name)
        reg_df["base_ext"] = reg_df.url_model.apply(lambda x: Path(x).name)

        metadata_new = reg_df.reset_index().merge(released.rename(columns={"pair": "id"}), on=["base_ext", "id"])

        metadata_renamer = {"src": "src_alpha3", "tgt": "tgt_alpha3", "id": "long_pair", "date": "train_date"}
        metadata_new = metadata_new.rename(columns=metadata_renamer)

        metadata_new["src_alpha2"] = metadata_new.short_pair.apply(lambda x: x.split("-")[0])
        metadata_new["tgt_alpha2"] = metadata_new.short_pair.apply(lambda x: x.split("-")[1])
        DROP_COLS_BOTH = ["url_base", "base_ext", "fname"]

        metadata_new = metadata_new.drop(DROP_COLS_BOTH, 1)
        metadata_new["prefer_old"] = metadata_new.long_pair.isin([])
        self.metadata = metadata_new
        assert self.metadata.short_pair.value_counts().max() == 1, "Multiple metadata entries for a short pair"
        self.metadata = self.metadata.set_index("short_pair")

        # wget.download(LANG_CODE_URL)
        mapper = pd.read_csv(LANG_CODE_PATH)
        mapper.columns = ["a3", "a2", "ref"]
        self.iso_table = pd.read_csv(ISO_PATH, sep="\t").rename(columns=lambda x: x.lower())
        more_3_to_2 = self.iso_table.set_index("id").part1.dropna().to_dict()
        more_3_to_2.update(mapper.set_index("a3").a2.to_dict())
        self.alpha3_to_alpha2 = more_3_to_2
        self.model_card_dir = Path(save_dir)
        self.constituents = GROUP_MEMBERS

    def convert_models(self, tatoeba_ids, dry_run=False):
        entries_to_convert = [x for x in self.registry if x[0] in tatoeba_ids]
        converted_paths = convert_all_sentencepiece_models(entries_to_convert, dest_dir=self.model_card_dir)

        for path in converted_paths:
            long_pair = remove_prefix(path.name, "opus-mt-").split("-")  # eg. heb-eng
            assert len(long_pair) == 2
            new_p_src = self.get_two_letter_code(long_pair[0])
            new_p_tgt = self.get_two_letter_code(long_pair[1])
            hf_model_id = f"opus-mt-{new_p_src}-{new_p_tgt}"
            new_path = path.parent.joinpath(hf_model_id)  # opus-mt-he-en
            os.rename(str(path), str(new_path))
            self.write_model_card(hf_model_id, dry_run=dry_run)

    def get_two_letter_code(self, three_letter_code):
        return self.alpha3_to_alpha2.get(three_letter_code, three_letter_code)

    def expand_group_to_two_letter_codes(self, grp_name):
        return [self.get_two_letter_code(x) for x in self.constituents[grp_name]]

    def get_tags(self, code, ref_name):
        if len(code) == 2:
            assert "languages" not in ref_name, f"{code}: {ref_name}"
            return [code], False
        elif "languages" in ref_name or len(self.constituents.get(code, [])) > 1:
            group = self.expand_group_to_two_letter_codes(code)
            group.append(code)
            return group, True
        else:  # zho-> zh
            print(f"Three letter monolingual code: {code}")
            return [code], False

    def resolve_lang_code(self, r) -> Tuple[List[str], str, str]:
        """R is a row in ported"""
        short_pair = r.short_pair
        src, tgt = short_pair.split("-")
        src_tags, src_multilingual = self.get_tags(src, r.src_name)
        assert isinstance(src_tags, list)
        tgt_tags, tgt_multilingual = self.get_tags(tgt, r.tgt_name)
        assert isinstance(tgt_tags, list)

        return dedup(src_tags + tgt_tags), src_multilingual, tgt_multilingual

    def write_model_card(
        self,
        hf_model_id: str,
        repo_root=DEFAULT_REPO,
        dry_run=False,
    ) -> str:
        """Copy the most recent model's readme section from opus, and add metadata.

        upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
        """
        short_pair = remove_prefix(hf_model_id, "opus-mt-")
        extra_metadata = self.metadata.loc[short_pair].drop("2m")
        extra_metadata["short_pair"] = short_pair
        lang_tags, src_multilingual, tgt_multilingual = self.resolve_lang_code(extra_metadata)
        opus_name = f"{extra_metadata.src_alpha3}-{extra_metadata.tgt_alpha3}"
        # opus_name: str = self.convert_hf_name_to_opus_name(hf_model_name)

        assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
        opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
        assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"

        opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]

        readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"

        s, t = ",".join(opus_src), ",".join(opus_tgt)

        metadata = {
            "hf_name": short_pair,
            "source_languages": s,
            "target_languages": t,
            "opus_readme_url": readme_url,
            "original_repo": repo_root,
            "tags": ["translation"],
            "languages": lang_tags,
        }
        lang_tags = l2front_matter(lang_tags)
        metadata["src_constituents"] = self.constituents[s]
        metadata["tgt_constituents"] = self.constituents[t]
        metadata["src_multilingual"] = src_multilingual
        metadata["tgt_multilingual"] = tgt_multilingual

        metadata.update(extra_metadata)
        metadata.update(get_system_metadata(repo_root))

        # combine with Tatoeba markdown
        extra_markdown = (
            f"### {short_pair}\n\n* source group: {metadata['src_name']}\n* target group: {metadata['tgt_name']}\n"
            f"* OPUS readme: [{opus_name}]({readme_url})\n"
        )

        content = opus_readme_path.open().read()
        content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.
        splat = content.split("*")[2:]

        content = "*".join(splat)
        # BETTER FRONT MATTER LOGIC
        content = (
            FRONT_MATTER_TEMPLATE.format(lang_tags)
            + extra_markdown
            + "\n* "
            + content.replace("download", "download original weights")
        )

        items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
        sec3 = "\n### System Info: \n" + items
        content += sec3
        if dry_run:
            return content, metadata
        sub_dir = self.model_card_dir / hf_model_id
        sub_dir.mkdir(exist_ok=True)
        dest = sub_dir / "README.md"
        dest.open("w").write(content)
        pd.Series(metadata).to_json(sub_dir / "metadata.json")
        return content, metadata

    def download_metadata(self):
        Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)
        import wget

        if not os.path.exists(ISO_PATH):
            wget.download(ISO_URL, ISO_PATH)
        if not os.path.exists(LANG_CODE_PATH):
            wget.download(LANG_CODE_URL, LANG_CODE_PATH)

    @staticmethod
    def make_tatoeba_registry(repo_path=DEFAULT_MODEL_DIR):
        if not (Path(repo_path) / "zho-eng" / "README.md").exists():
            raise ValueError(
                f"repo_path:{repo_path} does not exist: "
                "You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
            )
        results = {}
        for p in Path(repo_path).iterdir():
            if len(p.name) != 7:
                continue
            lns = list(open(p / "README.md").readlines())
            results[p.name] = _parse_readme(lns)
        return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]


GROUP_MEMBERS = {
    # three letter code -> (group/language name, {constituents...})
    # if this language is on the target side the constituents can be used as target language codes.
    # if the language is on the source side they are supported natively without special codes.
    "aav": ("Austro-Asiatic languages", {"hoc", "hoc_Latn", "kha", "khm", "khm_Latn", "mnw", "vie", "vie_Hani"}),
    "afa": (
        "Afro-Asiatic languages",
        {"acm", "afb", "amh", "apc", "ara", "arq", "ary", "arz", "hau_Latn", "heb", "kab", "mlt", "rif_Latn",
         "shy_Latn", "som", "thv", "tir"},
    ),
    "afr": ("Afrikaans", {"afr"}),
    "alv": (
        "Atlantic-Congo languages",
        {"ewe", "fuc", "fuv", "ibo", "kin", "lin", "lug", "nya", "run", "sag", "sna", "swh", "toi_Latn", "tso",
         "umb", "wol", "xho", "yor", "zul"},
    ),
    "ara": ("Arabic", {"afb", "apc", "apc_Latn", "ara", "ara_Latn", "arq", "arq_Latn", "arz"}),
    "art": (
        "Artificial languages",
        {"afh_Latn", "avk_Latn", "dws_Latn", "epo", "ido", "ido_Latn", "ile_Latn", "ina_Latn", "jbo", "jbo_Cyrl",
         "jbo_Latn", "ldn_Latn", "lfn_Cyrl", "lfn_Latn", "nov_Latn", "qya", "qya_Latn", "sjn_Latn", "tlh_Latn",
         "tzl", "tzl_Latn", "vol_Latn"},
    ),
    "aze": ("Azerbaijani", {"aze_Latn"}),
    "bat": ("Baltic languages", {"lit", "lav", "prg_Latn", "ltg", "sgs"}),
    "bel": ("Belarusian", {"bel", "bel_Latn"}),
    "ben": ("Bengali", {"ben"}),
    "bnt": (
        "Bantu languages",
        {"kin", "lin", "lug", "nya", "run", "sna", "swh", "toi_Latn", "tso", "umb", "xho", "zul"},
    ),
    "bul": ("Bulgarian", {"bul", "bul_Latn"}),
    "cat": ("Catalan", {"cat"}),
    "cau": ("Caucasian languages", {"abk", "kat", "che", "ady"}),
    "ccs": ("South Caucasian languages", {"kat"}),
    "ceb": ("Cebuano", {"ceb"}),
    "cel": ("Celtic languages", {"gla", "gle", "bre", "cor", "glv", "cym"}),
    "ces": ("Czech", {"ces"}),
    "cpf": ("Creoles and pidgins, French‑based", {"gcf_Latn", "hat", "mfe"}),
    "cpp": (
        "Creoles and pidgins, Portuguese-based",
        {"zsm_Latn", "ind", "pap", "min", "tmw_Latn", "max_Latn", "zlm_Latn"},
    ),
    "cus": ("Cushitic languages", {"som"}),
    "dan": ("Danish", {"dan"}),
    "deu": ("German", {"deu"}),
    "dra": ("Dravidian languages", {"tam", "kan", "mal", "tel"}),
    "ell": ("Modern Greek (1453-)", {"ell"}),
    "eng": ("English", {"eng"}),
    "epo": ("Esperanto", {"epo"}),
    "est": ("Estonian", {"est"}),
    "euq": ("Basque (family)", {"eus"}),
    "eus": ("Basque", {"eus"}),
    "fin": ("Finnish", {"fin"}),
    "fiu": (
        "Finno-Ugrian languages",
        {"est", "fin", "fkv_Latn", "hun", "izh", "kpv", "krl", "liv_Latn", "mdf", "mhr", "myv", "sma", "sme",
         "udm", "vep", "vro"},
    ),
    "fra": ("French", {"fra"}),
    "gem": (
        "Germanic languages",
        {"afr", "ang_Latn", "dan", "deu", "eng", "enm_Latn", "fao", "frr", "fry", "gos", "got_Goth", "gsw", "isl",
         "ksh", "ltz", "nds", "nld", "nno", "nob", "nob_Hebr", "non_Latn", "pdc", "sco", "stq", "swe", "swg",
         "yid"},
    ),
    "gle": ("Irish", {"gle"}),
    "glg": ("Galician", {"glg"}),
    "gmq": ("North Germanic languages", {"dan", "nob", "nob_Hebr", "swe", "isl", "nno", "non_Latn", "fao"}),
    "gmw": (
        "West Germanic languages",
        {"afr", "ang_Latn", "deu", "eng", "enm_Latn", "frr", "fry", "gos", "gsw", "ksh", "ltz", "nds", "nld",
         "pdc", "sco", "stq", "swg", "yid"},
    ),
    "grk": ("Greek languages", {"grc_Grek", "ell"}),
    "hbs": ("Serbo-Croatian", {"hrv", "srp_Cyrl", "bos_Latn", "srp_Latn"}),
    "heb": ("Hebrew", {"heb"}),
    "hin": ("Hindi", {"hin"}),
    "hun": ("Hungarian", {"hun"}),
    "hye": ("Armenian", {"hye", "hye_Latn"}),
    "iir": (
        "Indo-Iranian languages",
        {"asm", "awa", "ben", "bho", "gom", "guj", "hif_Latn", "hin", "jdt_Cyrl", "kur_Arab", "kur_Latn", "mai",
         "mar", "npi", "ori", "oss", "pan_Guru", "pes", "pes_Latn", "pes_Thaa", "pnb", "pus", "rom", "san_Deva",
         "sin", "snd_Arab", "tgk_Cyrl", "tly_Latn", "urd", "zza"},
    ),
    "ilo": ("Iloko", {"ilo"}),
    "inc": (
        "Indic languages",
        {"asm", "awa", "ben", "bho", "gom", "guj", "hif_Latn", "hin", "mai", "mar", "npi", "ori", "pan_Guru",
         "pnb", "rom", "san_Deva", "sin", "snd_Arab", "urd"},
    ),
    "ine": (
        "Indo-European languages",
        {"afr", "afr_Arab", "aln", "ang_Latn", "arg", "asm", "ast", "awa", "bel", "bel_Latn", "ben", "bho", "bjn",
         "bos_Latn", "bre", "bul", "bul_Latn", "cat", "ces", "cor", "cos", "csb_Latn", "cym", "dan", "deu", "dsb",
         "egl", "ell", "eng", "enm_Latn", "ext", "fao", "fra", "frm_Latn", "frr", "fry", "gcf_Latn", "gla", "gle",
         "glg", "glv", "gom", "gos", "got_Goth", "grc_Grek", "gsw", "guj", "hat", "hif_Latn", "hin", "hrv", "hsb",
         "hye", "hye_Latn", "ind", "isl", "ita", "jdt_Cyrl", "ksh", "kur_Arab", "kur_Latn", "lad", "lad_Latn",
         "lat_Grek", "lat_Latn", "lav", "lij", "lit", "lld_Latn", "lmo", "ltg", "ltz", "mai", "mar", "max_Latn",
         "mfe", "min", "mkd", "mwl", "nds", "nld", "nno", "nob", "nob_Hebr", "non_Latn", "npi", "oci", "ori",
         "orv_Cyrl", "oss", "pan_Guru", "pap", "pcd", "pdc", "pes", "pes_Latn", "pes_Thaa", "pms", "pnb", "pol",
         "por", "prg_Latn", "pus", "roh", "rom", "ron", "rue", "rus", "rus_Latn", "san_Deva", "scn", "sco", "sgs",
         "sin", "slv", "snd_Arab", "spa", "sqi", "srd", "srp_Cyrl", "srp_Latn", "stq", "swe", "swg", "tgk_Cyrl",
         "tly_Latn", "tmw_Latn", "ukr", "urd", "vec", "wln", "yid", "zlm_Latn", "zsm_Latn", "zza"},
    ),
    "isl": ("Icelandic", {"isl"}),
    "ita": ("Italian", {"ita"}),
    "itc": (
        "Italic languages",
        {"arg", "ast", "bjn", "cat", "cos", "egl", "ext", "fra", "frm_Latn", "gcf_Latn", "glg", "hat", "ind",
         "ita", "lad", "lad_Latn", "lat_Grek", "lat_Latn", "lij", "lld_Latn", "lmo", "max_Latn", "mfe", "min",
         "mwl", "oci", "pap", "pcd", "pms", "por", "roh", "ron", "scn", "spa", "srd", "tmw_Latn", "vec", "wln",
         "zlm_Latn", "zsm_Latn"},
    ),
    "jpn": (
        "Japanese",
        {"jpn", "jpn_Bopo", "jpn_Hang", "jpn_Hani", "jpn_Hira", "jpn_Kana", "jpn_Latn", "jpn_Yiii"},
    ),
    "jpx": ("Japanese (family)", {"jpn"}),
    "kat": ("Georgian", {"kat"}),
    "kor": ("Korean", {"kor_Hani", "kor_Hang", "kor_Latn", "kor"}),
    "lav": ("Latvian", {"lav"}),
    "lit": ("Lithuanian", {"lit"}),
    "mkd": ("Macedonian", {"mkd"}),
    "mkh": ("Mon-Khmer languages", {"vie_Hani", "mnw", "vie", "kha", "khm_Latn", "khm"}),
    "msa": ("Malay (macrolanguage)", {"zsm_Latn", "ind", "max_Latn", "zlm_Latn", "min"}),
    "mul": (
        "Multiple languages",
        {"abk", "acm", "ady", "afb", "afh_Latn", "afr", "akl_Latn", "aln", "amh", "ang_Latn", "apc", "ara", "arg",
         "arq", "ary", "arz", "asm", "ast", "avk_Latn", "awa", "aze_Latn", "bak", "bam_Latn", "bel", "bel_Latn",
         "ben", "bho", "bod", "bos_Latn", "bre", "brx", "brx_Latn", "bul", "bul_Latn", "cat", "ceb", "ces", "cha",
         "che", "chr", "chv", "cjy_Hans", "cjy_Hant", "cmn", "cmn_Hans", "cmn_Hant", "cor", "cos", "crh",
         "crh_Latn", "csb_Latn", "cym", "dan", "deu", "dsb", "dtp", "dws_Latn", "egl", "ell", "enm_Latn", "epo",
         "est", "eus", "ewe", "ext", "fao", "fij", "fin", "fkv_Latn", "fra", "frm_Latn", "frr", "fry", "fuc",
         "fuv", "gan", "gcf_Latn", "gil", "gla", "gle", "glg", "glv", "gom", "gos", "got_Goth", "grc_Grek", "grn",
         "gsw", "guj", "hat", "hau_Latn", "haw", "heb", "hif_Latn", "hil", "hin", "hnj_Latn", "hoc", "hoc_Latn",
         "hrv", "hsb", "hun", "hye", "iba", "ibo", "ido", "ido_Latn", "ike_Latn", "ile_Latn", "ilo", "ina_Latn",
         "ind", "isl", "ita", "izh", "jav", "jav_Java", "jbo", "jbo_Cyrl", "jbo_Latn", "jdt_Cyrl", "jpn", "kab",
         "kal", "kan", "kat", "kaz_Cyrl", "kaz_Latn", "kek_Latn", "kha", "khm", "khm_Latn", "kin", "kir_Cyrl",
         "kjh", "kpv", "krl", "ksh", "kum", "kur_Arab", "kur_Latn", "lad", "lad_Latn", "lao", "lat_Latn", "lav",
         "ldn_Latn", "lfn_Cyrl", "lfn_Latn", "lij", "lin", "lit", "liv_Latn", "lkt", "lld_Latn", "lmo", "ltg",
         "ltz", "lug", "lzh", "lzh_Hans", "mad", "mah", "mai", "mal", "mar", "max_Latn", "mdf", "mfe", "mhr",
         "mic", "min", "mkd", "mlg", "mlt", "mnw", "moh", "mon", "mri", "mwl", "mww", "mya", "myv", "nan", "nau",
         "nav", "nds", "niu", "nld", "nno", "nob", "nob_Hebr", "nog", "non_Latn", "nov_Latn", "npi", "nya", "oci",
         "ori", "orv_Cyrl", "oss", "ota_Arab", "ota_Latn", "pag", "pan_Guru", "pap", "pau", "pdc", "pes",
         "pes_Latn", "pes_Thaa", "pms", "pnb", "pol", "por", "ppl_Latn", "prg_Latn", "pus", "quc", "qya",
         "qya_Latn", "rap", "rif_Latn", "roh", "rom", "ron", "rue", "run", "rus", "sag", "sah", "san_Deva", "scn",
         "sco", "sgs", "shs_Latn", "shy_Latn", "sin", "sjn_Latn", "slv", "sma", "sme", "smo", "sna", "snd_Arab",
         "som", "spa", "sqi", "srp_Cyrl", "srp_Latn", "stq", "sun", "swe", "swg", "swh", "tah", "tam", "tat",
         "tat_Arab", "tat_Latn", "tel", "tet", "tgk_Cyrl", "tha", "tir", "tlh_Latn", "tly_Latn", "tmw_Latn",
         "toi_Latn", "ton", "tpw_Latn", "tso", "tuk", "tuk_Latn", "tur", "tvl", "tyv", "tzl", "tzl_Latn", "udm",
         "uig_Arab", "uig_Cyrl", "ukr", "umb", "urd", "uzb_Cyrl", "uzb_Latn", "vec", "vie", "vie_Hani",
         "vol_Latn", "vro", "war", "wln", "wol", "wuu", "xal", "xho", "yid", "yor", "yue", "yue_Hans", "yue_Hant",
         "zho", "zho_Hans", "zho_Hant", "zlm_Latn", "zsm_Latn", "zul", "zza"},
    ),
    "nic": (
        "Niger-Kordofanian languages",
        {"bam_Latn", "ewe", "fuc", "fuv", "ibo", "kin", "lin", "lug", "nya", "run", "sag", "sna", "swh",
         "toi_Latn", "tso", "umb", "wol", "xho", "yor", "zul"},
    ),
    "nld": ("Dutch", {"nld"}),
    "nor": ("Norwegian", {"nob", "nno"}),
    "phi": ("Philippine languages", {"ilo", "akl_Latn", "war", "hil", "pag", "ceb"}),
    "pol": ("Polish", {"pol"}),
    "por": ("Portuguese", {"por"}),
    "pqe": (
        "Eastern Malayo-Polynesian languages",
        {"fij", "gil", "haw", "mah", "mri", "nau", "niu", "rap", "smo", "tah", "ton", "tvl"},
    ),
    "roa": (
        "Romance languages",
        {"arg", "ast", "cat", "cos", "egl", "ext", "fra", "frm_Latn", "gcf_Latn", "glg", "hat", "ind", "ita",
         "lad", "lad_Latn", "lij", "lld_Latn", "lmo", "max_Latn", "mfe", "min", "mwl", "oci", "pap", "pms", "por",
         "roh", "ron", "scn", "spa", "tmw_Latn", "vec", "wln", "zlm_Latn", "zsm_Latn"},
    ),
    "ron": ("Romanian", {"ron"}),
    "run": ("Rundi", {"run"}),
    "rus": ("Russian", {"rus"}),
    "sal": ("Salishan languages", {"shs_Latn"}),
    "sem": ("Semitic languages", {"acm", "afb", "amh", "apc", "ara", "arq", "ary", "arz", "heb", "mlt", "tir"}),
    "sla": (
        "Slavic languages",
        {"bel", "bel_Latn", "bos_Latn", "bul", "bul_Latn", "ces", "csb_Latn", "dsb", "hrv", "hsb", "mkd",
         "orv_Cyrl", "pol", "rue", "rus", "slv", "srp_Cyrl", "srp_Latn", "ukr"},
    ),
    "slv": ("Slovenian", {"slv"}),
    "spa": ("Spanish", {"spa"}),
    "swe": ("Swedish", {"swe"}),
    "taw": ("Tai", {"lao", "tha"}),
    "tgl": ("Tagalog", {"tgl_Latn"}),
    "tha": ("Thai", {"tha"}),
    "trk": (
        "Turkic languages",
        {"aze_Latn", "bak", "chv", "crh", "crh_Latn", "kaz_Cyrl", "kaz_Latn", "kir_Cyrl", "kjh", "kum",
         "ota_Arab", "ota_Latn", "sah", "tat", "tat_Arab", "tat_Latn", "tuk", "tuk_Latn", "tur", "tyv",
         "uig_Arab", "uig_Cyrl", "uzb_Cyrl", "uzb_Latn"},
    ),
    "tur": ("Turkish", {"tur"}),
    "ukr": ("Ukrainian", {"ukr"}),
    "urd": ("Urdu", {"urd"}),
    "urj": (
        "Uralic languages",
        {"est", "fin", "fkv_Latn", "hun", "izh", "kpv", "krl", "liv_Latn", "mdf", "mhr", "myv", "sma", "sme",
         "udm", "vep", "vro"},
    ),
    "vie": ("Vietnamese", {"vie", "vie_Hani"}),
    "war": ("Waray (Philippines)", {"war"}),
    "zho": (
        "Chinese",
        {"cjy_Hans", "cjy_Hant", "cmn", "cmn_Bopo", "cmn_Hang", "cmn_Hani", "cmn_Hans", "cmn_Hant", "cmn_Hira",
         "cmn_Kana", "cmn_Latn", "cmn_Yiii", "gan", "hak_Hani", "lzh", "lzh_Bopo", "lzh_Hang", "lzh_Hani",
         "lzh_Hans", "lzh_Hira", "lzh_Kana", "lzh_Yiii", "nan", "nan_Hani", "wuu", "wuu_Bopo", "wuu_Hani",
         "wuu_Latn", "yue", "yue_Bopo", "yue_Hang", "yue_Hani", "yue_Hans", "yue_Hant", "yue_Hira", "yue_Kana",
         "zho", "zho_Hans", "zho_Hant"},
    ),
    "zle": ("East Slavic languages", {"bel", "orv_Cyrl", "bel_Latn", "rus", "ukr", "rue"}),
    "zls": (
        "South Slavic languages",
        {"bos_Latn", "bul", "bul_Latn", "hrv", "mkd", "slv", "srp_Cyrl", "srp_Latn"},
    ),
    "zlw": ("West Slavic languages", {"csb_Latn", "dsb", "hsb", "pol", "ces"}),
}


def l2front_matter(langs):
    return "".join(f"- {l}\n" for l in langs)


def dedup(lst):
    """Preserves order"""
    new_lst = []
    for item in lst:
        if not item:
            continue
        elif item in new_lst:
            continue
        else:
            new_lst.append(item)
    return new_lst


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--models", action="append", help="<Required> Set flag", required=True, nargs="+", dest="models"
    )
    parser.add_argument("-save_dir", "--save_dir", default="marian_converted", help="where to save converted models")
    args = parser.parse_args()
    resolver = TatoebaConverter(save_dir=args.save_dir)
    resolver.convert_models(args.models[0])
```
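For orientation when reading `convert_models` and `make_tatoeba_registry` above: each registry entry is a 4-tuple matching the `["id", "prepro", "url_model", "url_test_set"]` columns loaded in `__init__`, and only entries whose pre-processing string mentions SentencePiece are converted. A sketch with placeholder values (the URLs below are illustrative, not real release artifacts):

```python
# Hypothetical registry entry; the field layout follows the DataFrame columns in TatoebaConverter.__init__.
registry_entry = (
    "heb-eng",                                        # id: alpha-3 pair, i.e. the model directory name
    "normalization + SentencePiece (spm32k,spm32k)",  # prepro: parsed from the pair's README.md
    "https://object.pouta.csc.fi/Tatoeba-Challenge/...zip",       # url_model: zip holding the Marian checkpoint
    "https://object.pouta.csc.fi/Tatoeba-Challenge/...test.txt",  # url_test_set: url_model with ".zip" -> ".test.txt"
)
```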
src/transformers/convert_marian_to_pytorch.py
```diff
 import argparse
 import json
 import os
 import shutil
 import socket
 import time
 import warnings
 from pathlib import Path
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Union
 from zipfile import ZipFile

 import numpy as np

@@ -23,85 +22,6 @@ def remove_suffix(text: str, suffix: str):
     return text  # or whatever


-def _process_benchmark_table_row(x):
-    fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
-    assert len(fields) == 3
-    return (fields[0], float(fields[1]), float(fields[2]))
-
-
-def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
-    md_content = Path(readme_path).open().read()
-    entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
-    data = lmap(_process_benchmark_table_row, entries)
-    return data
-
-
-def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
-    """Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
-    import pandas as pd
-
-    newest_released, old_reg, released = get_released_df(new_repo_path, old_repo_path)
-
-    short_to_new_bleu = newest_released.set_index("short_pair").bleu
-
-    assert released.groupby("short_pair").pair.nunique().max() == 1
-
-    short_to_long = released.groupby("short_pair").pair.first().to_dict()
-
-    overlap_short = old_reg.index.intersection(released.short_pair.unique())
-    overlap_long = [short_to_long[o] for o in overlap_short]
-    new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
-
-    def get_old_bleu(o) -> float:
-        pat = old_repo_path + "/{}/README.md"
-        bm_data = process_last_benchmark_table(pat.format(o))
-        tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
-        tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
-        if tato_bleu.shape[0] > 0:
-            return tato_bleu.iloc[0]
-        else:
-            return np.nan
-
-    old_bleu = [get_old_bleu(o) for o in overlap_short]
-    cmp_df = pd.DataFrame(
-        dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
-    ).fillna(-1)
-
-    dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
-    whitelist_df = cmp_df[cmp_df.old_bleu <= cmp_df.new_bleu]
-    blacklist = dominated.long.unique().tolist()  # 3 letter codes
-    return whitelist_df, dominated, blacklist
-
-
-def get_released_df(new_repo_path, old_repo_path):
-    import pandas as pd
-
-    released_cols = [
-        "url_base",
-        "pair",  # (ISO639-3/ISO639-5 codes),
-        "short_pair",  # (reduced codes),
-        "chrF2_score",
-        "bleu",
-        "brevity_penalty",
-        "ref_len",
-        "src_name",
-        "tgt_name",
-    ]
-    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
-    released.columns = released_cols
-    old_reg = make_registry(repo_path=old_repo_path)
-    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
-    assert old_reg.id.value_counts().max() == 1
-    old_reg = old_reg.set_index("id")
-    released["fname"] = released["url_base"].apply(
-        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
-    )
-    released["2m"] = released.fname.str.startswith("2m")
-    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
-    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
-    return newest_released, old_reg, released
-
-
 def remove_prefix(text: str, prefix: str):
     if text.startswith(prefix):
         return text[len(prefix) :]

@@ -183,7 +103,11 @@ def find_model_file(dest_dir):  # this one better
 # Group Names Logic: change long opus model names to something shorter, like opus-mt-en-ROMANCE
-ROM_GROUP = "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
+ROM_GROUP = (
+    "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO+es_EC+es_ES+es_GT"
+    "+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR+pt_PT+gl+lad+an+mwl+it+it_IT+co"
+    "+nap+scn+vec+sc+ro+la"
+)
 GROUPS = [
     ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
     (ROM_GROUP, "ROMANCE"),

@@ -221,13 +145,15 @@ ORG_NAME = "Helsinki-NLP/"
 def convert_opus_name_to_hf_name(x):
+    """For OPUS-MT-Train/ DEPRECATED"""
     for substr, grp_name in GROUPS:
         x = x.replace(substr, grp_name)
     return x.replace("+", "_")


 def convert_hf_name_to_opus_name(hf_model_name):
-    """Relies on the assumption that there are no language codes like pt_br in models that are not in GROUP_TO_OPUS_NAME."""
+    """Relies on the assumption that there are no language codes like pt_br in models that are not in
+    GROUP_TO_OPUS_NAME."""
     hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
     if hf_model_name in GROUP_TO_OPUS_NAME:
         opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]

@@ -247,8 +173,9 @@ def get_system_metadata(repo_root):
     )

-front_matter = """---
-language: {}
+FRONT_MATTER_TEMPLATE = """---
+language:
+{}
 tags:
 - translation

@@ -256,11 +183,13 @@ license: apache-2.0
 ---
 """

+DEFAULT_REPO = "Tatoeba-Challenge"
+DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")


 def write_model_card(
     hf_model_name: str,
-    repo_root="OPUS-MT-train",
+    repo_root=DEFAULT_REPO,
     save_dir=Path("marian_converted"),
     dry_run=False,
     extra_metadata={},

@@ -294,7 +223,10 @@ def write_model_card(
     # combine with opus markdown
-    extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']}\n* target group: {metadata['tgt_name']}\n* OPUS readme: [{opus_name}]({readme_url})\n"
+    extra_markdown = (
+        f"### {hf_model_name}\n\n* source group: {metadata['src_name']}\n* target group: "
+        f"{metadata['tgt_name']}\n* OPUS readme: [{opus_name}]({readme_url})\n"
+    )

     content = opus_readme_path.open().read()
     content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.

@@ -302,7 +234,7 @@ def write_model_card(
     print(splat[3])
     content = "*".join(splat)
     content = (
-        front_matter.format(metadata["src_alpha2"])
+        FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
         + extra_markdown
         + "\n* "
         + content.replace("download", "download original weights")

@@ -323,48 +255,6 @@ def write_model_card(
     return content, metadata


-def get_clean_model_id_mapping(multiling_model_ids):
-    return {x: convert_opus_name_to_hf_name(x) for x in multiling_model_ids}
-
-
-def expand_group_to_two_letter_codes(grp_name):
-    raise NotImplementedError()
-
-
-def get_two_letter_code(three_letter_code):
-    raise NotImplementedError()
-    # return two_letter_code
-
-
-def get_tags(code, ref_name):
-    if len(code) == 2:
-        assert "languages" not in ref_name, f"{code}: {ref_name}"
-        return [code], False
-    elif "languages" in ref_name:
-        group = expand_group_to_two_letter_codes(code)
-        group.append(code)
-        return group, True
-    else:  # zho-> zh
-        raise ValueError(f"Three letter monolingual code: {code}")
-
-
-def resolve_lang_code(r):
-    """R is a row in ported"""
-    short_pair = r.short_pair
-    src, tgt = short_pair.split("-")
-    src_tags, src_multilingual = get_tags(src, r.src_name)
-    assert isinstance(src_tags, list)
-    tgt_tags, tgt_multilingual = get_tags(src, r.tgt_name)
-    assert isinstance(tgt_tags, list)
-    if src_multilingual:
-        src_tags.append("multilingual_src")
-    if tgt_multilingual:
-        tgt_tags.append("multilingual_tgt")
-    return src_tags + tgt_tags
-
-
-# process target
 def make_registry(repo_path="Opus-MT-train/models"):
     if not (Path(repo_path) / "fr-en" / "README.md").exists():
         raise ValueError(

@@ -382,36 +272,25 @@ def make_registry(repo_path="Opus-MT-train/models"):
     return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]


-def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
-    if not (Path(repo_path) / "zho-eng" / "README.md").exists():
-        raise ValueError(
-            f"repo_path:{repo_path} does not exist: "
-            "You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
-        )
-    results = {}
-    for p in Path(repo_path).iterdir():
-        if len(p.name) != 7:
-            continue
-        lns = list(open(p / "README.md").readlines())
-        results[p.name] = _parse_readme(lns)
-    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
-
-
-def convert_all_sentencepiece_models(model_list=None, repo_path=None):
+def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
     """Requires 300GB"""
     save_dir = Path("marian_ckpt")
-    dest_dir = Path("marian_converted")
+    dest_dir = Path(dest_dir)
     dest_dir.mkdir(exist_ok=True)
     save_paths = []
     if model_list is None:
         model_list: list = make_registry(repo_path=repo_path)
     for k, prepro, download, test_set_url in tqdm(model_list):
         if "SentencePiece" not in prepro:  # dont convert BPE models.
             continue
         if not os.path.exists(save_dir / k / "pytorch_model.bin"):
             if not os.path.exists(save_dir / k):
                 download_and_unzip(download, save_dir / k)
             pair_name = convert_opus_name_to_hf_name(k)
             convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")

         save_paths.append(dest_dir / f"opus-mt-{pair_name}")
     return save_paths


 def lmap(f, x) -> List:
     return list(map(f, x))

@@ -493,15 +372,6 @@ def add_special_tokens_to_vocab(model_dir: Path) -> None:
     save_tokenizer_config(model_dir)


-def save_tokenizer(self, save_directory):
-    dest = Path(save_directory)
-    src_path = Path(self.init_kwargs["source_spm"])
-    for dest_name in {"source.spm", "target.spm", "tokenizer_config.json"}:
-        shutil.copyfile(src_path.parent / dest_name, dest / dest_name)
-    save_json(self.encoder, dest / "vocab.json")
-
-
 def check_equal(marian_cfg, k1, k2):
     v1, v2 = marian_cfg[k1], marian_cfg[k2]
     assert v1 == v2, f"hparams {k1},{k2} differ: {v1} != {v2}"

@@ -698,14 +568,14 @@ def convert(source_dir: Path, dest_dir):
     add_special_tokens_to_vocab(source_dir)
     tokenizer = MarianTokenizer.from_pretrained(str(source_dir))
-    save_tokenizer(tokenizer, dest_dir)
+    tokenizer.save_pretrained(dest_dir)

     opus_state = OpusState(source_dir)
     assert opus_state.cfg["vocab_size"] == len(tokenizer.encoder), f"Original vocab size {opus_state.cfg['vocab_size']} and new vocab size {len(tokenizer.encoder)} mismatched"
     # save_json(opus_state.cfg, dest_dir / "marian_original_config.json")
-    # ^^ Save human readable marian config for debugging
+    # ^^ Uncomment to save human readable marian config for debugging

     model = opus_state.load_marian_model()
     model = model.half()

@@ -732,15 +602,11 @@ def unzip(zip_path: str, dest_dir: str) -> None:
 if __name__ == "__main__":
     """
-    To bulk convert, run
-    >>> from transformers.convert_marian_to_pytorch import make_tatoeba_registry, convert_all_sentencepiece_models
-    >>> reg = make_tatoeba_registry()
-    >>> convert_all_sentencepiece_models(model_list=reg)  # saves to marian_converted
-    (bash) aws s3 sync marian_converted s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
+    Tatoeba conversion instructions in scripts/tatoeba/README.md
     """
     parser = argparse.ArgumentParser()
     # Required parameters
-    parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
+    parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
     parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
     args = parser.parse_args()
```