Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
38f17037
Unverified
Commit
38f17037
authored
Sep 23, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 23, 2020
Browse files
wip: Code to add lang tags to marian model cards (#6586)
parent
129fdae0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
28 deletions
+71
-28
src/transformers/convert_marian_to_pytorch.py
src/transformers/convert_marian_to_pytorch.py
+71
-28
No files found.
src/transformers/convert_marian_to_pytorch.py
View file @
38f17037
...
...
@@ -40,33 +40,7 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
import
pandas
as
pd
released_cols
=
[
"url_base"
,
"pair"
,
# (ISO639-3/ISO639-5 codes),
"short_pair"
,
# (reduced codes),
"chrF2_score"
,
"bleu"
,
"brevity_penalty"
,
"ref_len"
,
"src_name"
,
"tgt_name"
,
]
released
=
pd
.
read_csv
(
f
"
{
new_repo_path
}
/released-models.txt"
,
sep
=
"
\t
"
,
header
=
None
).
iloc
[:
-
1
]
released
.
columns
=
released_cols
old_reg
=
make_registry
(
repo_path
=
old_repo_path
)
old_reg
=
pd
.
DataFrame
(
old_reg
,
columns
=
[
"id"
,
"prepro"
,
"url_model"
,
"url_test_set"
])
assert
old_reg
.
id
.
value_counts
().
max
()
==
1
old_reg
=
old_reg
.
set_index
(
"id"
)
released
[
"fname"
]
=
released
[
"url_base"
].
apply
(
lambda
x
:
remove_suffix
(
remove_prefix
(
x
,
"https://object.pouta.csc.fi/Tatoeba-Challenge/opus"
),
".zip"
)
)
released
[
"2m"
]
=
released
.
fname
.
str
.
startswith
(
"2m"
)
released
[
"date"
]
=
pd
.
to_datetime
(
released
[
"fname"
].
apply
(
lambda
x
:
remove_prefix
(
remove_prefix
(
x
,
"2m-"
),
"-"
)))
newest_released
=
released
.
dsort
(
"date"
).
drop_duplicates
([
"short_pair"
],
keep
=
"first"
)
newest_released
,
old_reg
,
released
=
get_released_df
(
new_repo_path
,
old_repo_path
)
short_to_new_bleu
=
newest_released
.
set_index
(
"short_pair"
).
bleu
...
...
@@ -94,8 +68,38 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
).
fillna
(
-
1
)
dominated
=
cmp_df
[
cmp_df
.
old_bleu
>
cmp_df
.
new_bleu
]
whitelist_df
=
cmp_df
[
cmp_df
.
old_bleu
<=
cmp_df
.
new_bleu
]
blacklist
=
dominated
.
long
.
unique
().
tolist
()
# 3 letter codes
return
dominated
,
blacklist
return
whitelist_df
,
dominated
,
blacklist
def
get_released_df
(
new_repo_path
,
old_repo_path
):
import
pandas
as
pd
released_cols
=
[
"url_base"
,
"pair"
,
# (ISO639-3/ISO639-5 codes),
"short_pair"
,
# (reduced codes),
"chrF2_score"
,
"bleu"
,
"brevity_penalty"
,
"ref_len"
,
"src_name"
,
"tgt_name"
,
]
released
=
pd
.
read_csv
(
f
"
{
new_repo_path
}
/released-models.txt"
,
sep
=
"
\t
"
,
header
=
None
).
iloc
[:
-
1
]
released
.
columns
=
released_cols
old_reg
=
make_registry
(
repo_path
=
old_repo_path
)
old_reg
=
pd
.
DataFrame
(
old_reg
,
columns
=
[
"id"
,
"prepro"
,
"url_model"
,
"url_test_set"
])
assert
old_reg
.
id
.
value_counts
().
max
()
==
1
old_reg
=
old_reg
.
set_index
(
"id"
)
released
[
"fname"
]
=
released
[
"url_base"
].
apply
(
lambda
x
:
remove_suffix
(
remove_prefix
(
x
,
"https://object.pouta.csc.fi/Tatoeba-Challenge/opus"
),
".zip"
)
)
released
[
"2m"
]
=
released
.
fname
.
str
.
startswith
(
"2m"
)
released
[
"date"
]
=
pd
.
to_datetime
(
released
[
"fname"
].
apply
(
lambda
x
:
remove_prefix
(
remove_prefix
(
x
,
"2m-"
),
"-"
)))
newest_released
=
released
.
dsort
(
"date"
).
drop_duplicates
([
"short_pair"
],
keep
=
"first"
)
return
newest_released
,
old_reg
,
released
def
remove_prefix
(
text
:
str
,
prefix
:
str
):
...
...
@@ -323,6 +327,44 @@ def get_clean_model_id_mapping(multiling_model_ids):
return
{
x
:
convert_opus_name_to_hf_name
(
x
)
for
x
in
multiling_model_ids
}
def
expand_group_to_two_letter_codes
(
grp_name
):
raise
NotImplementedError
()
def
get_two_letter_code
(
three_letter_code
):
raise
NotImplementedError
()
# return two_letter_code
def
get_tags
(
code
,
ref_name
):
if
len
(
code
)
==
2
:
assert
"languages"
not
in
ref_name
,
f
"
{
code
}
:
{
ref_name
}
"
return
[
code
],
False
elif
"languages"
in
ref_name
:
group
=
expand_group_to_two_letter_codes
(
code
)
group
.
append
(
code
)
return
group
,
True
else
:
# zho-> zh
raise
ValueError
(
f
"Three letter monolingual code:
{
code
}
"
)
def
resolve_lang_code
(
r
):
"""R is a row in ported"""
short_pair
=
r
.
short_pair
src
,
tgt
=
short_pair
.
split
(
"-"
)
src_tags
,
src_multilingual
=
get_tags
(
src
,
r
.
src_name
)
assert
isinstance
(
src_tags
,
list
)
tgt_tags
,
tgt_multilingual
=
get_tags
(
src
,
r
.
tgt_name
)
assert
isinstance
(
tgt_tags
,
list
)
if
src_multilingual
:
src_tags
.
append
(
"multilingual_src"
)
if
tgt_multilingual
:
tgt_tags
.
append
(
"multilingual_tgt"
)
return
src_tags
+
tgt_tags
# process target
def
make_registry
(
repo_path
=
"Opus-MT-train/models"
):
if
not
(
Path
(
repo_path
)
/
"fr-en"
/
"README.md"
).
exists
():
raise
ValueError
(
...
...
@@ -666,6 +708,7 @@ def convert(source_dir: Path, dest_dir):
# ^^ Save human readable marian config for debugging
model
=
opus_state
.
load_marian_model
()
model
=
model
.
half
()
model
.
save_pretrained
(
dest_dir
)
model
.
from_pretrained
(
dest_dir
)
# sanity check
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment