Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
38f17037
Unverified
Commit
38f17037
authored
Sep 23, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 23, 2020
Browse files
wip: Code to add lang tags to marian model cards (#6586)
parent
129fdae0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
28 deletions
+71
-28
src/transformers/convert_marian_to_pytorch.py
src/transformers/convert_marian_to_pytorch.py
+71
-28
No files found.
src/transformers/convert_marian_to_pytorch.py
View file @
38f17037
...
...
@@ -40,33 +40,7 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
"""Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
import
pandas
as
pd
released_cols
=
[
"url_base"
,
"pair"
,
# (ISO639-3/ISO639-5 codes),
"short_pair"
,
# (reduced codes),
"chrF2_score"
,
"bleu"
,
"brevity_penalty"
,
"ref_len"
,
"src_name"
,
"tgt_name"
,
]
released
=
pd
.
read_csv
(
f
"
{
new_repo_path
}
/released-models.txt"
,
sep
=
"
\t
"
,
header
=
None
).
iloc
[:
-
1
]
released
.
columns
=
released_cols
old_reg
=
make_registry
(
repo_path
=
old_repo_path
)
old_reg
=
pd
.
DataFrame
(
old_reg
,
columns
=
[
"id"
,
"prepro"
,
"url_model"
,
"url_test_set"
])
assert
old_reg
.
id
.
value_counts
().
max
()
==
1
old_reg
=
old_reg
.
set_index
(
"id"
)
released
[
"fname"
]
=
released
[
"url_base"
].
apply
(
lambda
x
:
remove_suffix
(
remove_prefix
(
x
,
"https://object.pouta.csc.fi/Tatoeba-Challenge/opus"
),
".zip"
)
)
released
[
"2m"
]
=
released
.
fname
.
str
.
startswith
(
"2m"
)
released
[
"date"
]
=
pd
.
to_datetime
(
released
[
"fname"
].
apply
(
lambda
x
:
remove_prefix
(
remove_prefix
(
x
,
"2m-"
),
"-"
)))
newest_released
=
released
.
dsort
(
"date"
).
drop_duplicates
([
"short_pair"
],
keep
=
"first"
)
newest_released
,
old_reg
,
released
=
get_released_df
(
new_repo_path
,
old_repo_path
)
short_to_new_bleu
=
newest_released
.
set_index
(
"short_pair"
).
bleu
...
...
@@ -94,8 +68,38 @@ def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo
).
fillna
(
-
1
)
dominated
=
cmp_df
[
cmp_df
.
old_bleu
>
cmp_df
.
new_bleu
]
whitelist_df
=
cmp_df
[
cmp_df
.
old_bleu
<=
cmp_df
.
new_bleu
]
blacklist
=
dominated
.
long
.
unique
().
tolist
()
# 3 letter codes
return
dominated
,
blacklist
return
whitelist_df
,
dominated
,
blacklist
def
get_released_df
(
new_repo_path
,
old_repo_path
):
import
pandas
as
pd
released_cols
=
[
"url_base"
,
"pair"
,
# (ISO639-3/ISO639-5 codes),
"short_pair"
,
# (reduced codes),
"chrF2_score"
,
"bleu"
,
"brevity_penalty"
,
"ref_len"
,
"src_name"
,
"tgt_name"
,
]
released
=
pd
.
read_csv
(
f
"
{
new_repo_path
}
/released-models.txt"
,
sep
=
"
\t
"
,
header
=
None
).
iloc
[:
-
1
]
released
.
columns
=
released_cols
old_reg
=
make_registry
(
repo_path
=
old_repo_path
)
old_reg
=
pd
.
DataFrame
(
old_reg
,
columns
=
[
"id"
,
"prepro"
,
"url_model"
,
"url_test_set"
])
assert
old_reg
.
id
.
value_counts
().
max
()
==
1
old_reg
=
old_reg
.
set_index
(
"id"
)
released
[
"fname"
]
=
released
[
"url_base"
].
apply
(
lambda
x
:
remove_suffix
(
remove_prefix
(
x
,
"https://object.pouta.csc.fi/Tatoeba-Challenge/opus"
),
".zip"
)
)
released
[
"2m"
]
=
released
.
fname
.
str
.
startswith
(
"2m"
)
released
[
"date"
]
=
pd
.
to_datetime
(
released
[
"fname"
].
apply
(
lambda
x
:
remove_prefix
(
remove_prefix
(
x
,
"2m-"
),
"-"
)))
newest_released
=
released
.
dsort
(
"date"
).
drop_duplicates
([
"short_pair"
],
keep
=
"first"
)
return
newest_released
,
old_reg
,
released
def
remove_prefix
(
text
:
str
,
prefix
:
str
):
...
...
@@ -323,6 +327,44 @@ def get_clean_model_id_mapping(multiling_model_ids):
return
{
x
:
convert_opus_name_to_hf_name
(
x
)
for
x
in
multiling_model_ids
}
def
expand_group_to_two_letter_codes
(
grp_name
):
raise
NotImplementedError
()
def
get_two_letter_code
(
three_letter_code
):
raise
NotImplementedError
()
# return two_letter_code
def
get_tags
(
code
,
ref_name
):
if
len
(
code
)
==
2
:
assert
"languages"
not
in
ref_name
,
f
"
{
code
}
:
{
ref_name
}
"
return
[
code
],
False
elif
"languages"
in
ref_name
:
group
=
expand_group_to_two_letter_codes
(
code
)
group
.
append
(
code
)
return
group
,
True
else
:
# zho-> zh
raise
ValueError
(
f
"Three letter monolingual code:
{
code
}
"
)
def
resolve_lang_code
(
r
):
"""R is a row in ported"""
short_pair
=
r
.
short_pair
src
,
tgt
=
short_pair
.
split
(
"-"
)
src_tags
,
src_multilingual
=
get_tags
(
src
,
r
.
src_name
)
assert
isinstance
(
src_tags
,
list
)
tgt_tags
,
tgt_multilingual
=
get_tags
(
src
,
r
.
tgt_name
)
assert
isinstance
(
tgt_tags
,
list
)
if
src_multilingual
:
src_tags
.
append
(
"multilingual_src"
)
if
tgt_multilingual
:
tgt_tags
.
append
(
"multilingual_tgt"
)
return
src_tags
+
tgt_tags
# process target
def
make_registry
(
repo_path
=
"Opus-MT-train/models"
):
if
not
(
Path
(
repo_path
)
/
"fr-en"
/
"README.md"
).
exists
():
raise
ValueError
(
...
...
@@ -666,6 +708,7 @@ def convert(source_dir: Path, dest_dir):
# ^^ Save human readable marian config for debugging
model
=
opus_state
.
load_marian_model
()
model
=
model
.
half
()
model
.
save_pretrained
(
dest_dir
)
model
.
from_pretrained
(
dest_dir
)
# sanity check
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment