Commit 12d76241 (unverified), authored Aug 17, 2020 by Sam Shleifer, committed by GitHub on Aug 17, 2020

[marian] converter supports models from new Tatoeba project (#6342)

parent fb7330b3

Showing 3 changed files with 200 additions and 34 deletions (+200 -34)
docs/source/model_doc/marian.rst                 +3    -3
src/transformers/convert_marian_to_pytorch.py    +186  -31
tests/test_modeling_marian.py                    +11   -0
docs/source/model_doc/marian.rst  (view file @ 12d76241)
 MarianMT
 ----------------------------------------------------

-**DISCLAIMER:** If you see something strange, file a `Github Issue
-<https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__
-and assign @sshleifer.
+**Bugs:** If you see something strange, file a `Github Issue
+<https://github.com/huggingface/transformers/issues/new?assignees=sshleifer&labels=&template=bug-report.md&title>`__
+and assign @sshleifer.
 Translations should be similar, but not identical to, output in the test set linked to in each model card.

 Implementation Notes
 ~~~~~~~~~~~~~~~~~~~~

 - Each model is about 298 MB on disk, there are 1,000+ models.
 - The list of supported language pairs can be found `here <https://huggingface.co/Helsinki-NLP>`__.
 - The 1,000+ models were originally trained by `Jörg Tiedemann <https://researchportal.helsinki.fi/en/persons/j%C3%B6rg-tiedemann>`__ using the `Marian <https://marian-nmt.github.io/>`_ C++ library, which supports fast training and translation.
 - All models are transformer encoder-decoders with 6 layers in each component. Each model's performance is documented in a model card.
 - The 80 opus models that require BPE preprocessing are not supported.
 - The modeling code is the same as ``BartForConditionalGeneration`` with a few minor modifications:
 ...
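For orientation, here is a minimal sketch of loading one of these converted checkpoints and translating a sentence. The pair en-de is just one existing example, and exact tokenizer call names can differ slightly between transformers versions:

    from transformers import MarianMTModel, MarianTokenizer

    model_name = "Helsinki-NLP/opus-mt-en-de"  # any supported pair works; en-de chosen for illustration
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)

    batch = tokenizer(["I like to read books."], return_tensors="pt", padding=True)
    generated = model.generate(**batch)
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))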
src/transformers/convert_marian_to_pytorch.py  (view file @ 12d76241)
...
@@ -2,9 +2,11 @@ import argparse
 import json
 import os
 import shutil
+import socket
+import time
 import warnings
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union
 from zipfile import ZipFile

 import numpy as np
...
@@ -15,6 +17,87 @@ from transformers import MarianConfig, MarianMTModel, MarianTokenizer
 from transformers.hf_api import HfApi


+def remove_suffix(text: str, suffix: str):
+    if text.endswith(suffix):
+        return text[: -len(suffix)]
+    return text  # or whatever
+
+
+def _process_benchmark_table_row(x):
+    fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
+    assert len(fields) == 3
+    return (fields[0], float(fields[1]), float(fields[2]))
+
+
+def process_last_benchmark_table(readme_path) -> List[Tuple[str, float, float]]:
+    md_content = Path(readme_path).open().read()
+    entries = md_content.split("## Benchmarks")[-1].strip().split("\n")[2:]
+    data = lmap(_process_benchmark_table_row, entries)
+    return data
+
+
+def check_if_models_are_dominated(old_repo_path="OPUS-MT-train/models", new_repo_path="Tatoeba-Challenge/models/"):
+    """Make a blacklist for models where we have already ported the same language pair, and the ported model has higher BLEU score."""
+    import pandas as pd
+
+    released_cols = [
+        "url_base",
+        "pair",  # (ISO639-3/ISO639-5 codes)
+        "short_pair",  # (reduced codes)
+        "chrF2_score",
+        "bleu",
+        "brevity_penalty",
+        "ref_len",
+        "src_name",
+        "tgt_name",
+    ]
+    released = pd.read_csv(f"{new_repo_path}/released-models.txt", sep="\t", header=None).iloc[:-1]
+    released.columns = released_cols
+    old_reg = make_registry(repo_path=old_repo_path)
+    old_reg = pd.DataFrame(old_reg, columns=["id", "prepro", "url_model", "url_test_set"])
+    assert old_reg.id.value_counts().max() == 1
+    old_reg = old_reg.set_index("id")
+    released["fname"] = released["url_base"].apply(
+        lambda x: remove_suffix(remove_prefix(x, "https://object.pouta.csc.fi/Tatoeba-Challenge/opus"), ".zip")
+    )
+    released["2m"] = released.fname.str.startswith("2m")
+    released["date"] = pd.to_datetime(released["fname"].apply(lambda x: remove_prefix(remove_prefix(x, "2m-"), "-")))
+
+    newest_released = released.dsort("date").drop_duplicates(["short_pair"], keep="first")
+    short_to_new_bleu = newest_released.set_index("short_pair").bleu
+
+    assert released.groupby("short_pair").pair.nunique().max() == 1
+
+    short_to_long = released.groupby("short_pair").pair.first().to_dict()
+
+    overlap_short = old_reg.index.intersection(released.short_pair.unique())
+    overlap_long = [short_to_long[o] for o in overlap_short]
+    new_reported_bleu = [short_to_new_bleu[o] for o in overlap_short]
+
+    def get_old_bleu(o) -> float:
+        pat = old_repo_path + "/{}/README.md"
+        bm_data = process_last_benchmark_table(pat.format(o))
+        tab = pd.DataFrame(bm_data, columns=["testset", "bleu", "chr-f"])
+        tato_bleu = tab.loc[lambda x: x.testset.str.startswith("Tato")].bleu
+        if tato_bleu.shape[0] > 0:
+            return tato_bleu.iloc[0]
+        else:
+            return np.nan
+
+    old_bleu = [get_old_bleu(o) for o in overlap_short]
+    cmp_df = pd.DataFrame(
+        dict(short=overlap_short, long=overlap_long, old_bleu=old_bleu, new_bleu=new_reported_bleu)
+    ).fillna(-1)
+
+    dominated = cmp_df[cmp_df.old_bleu > cmp_df.new_bleu]
+    blacklist = dominated.long.unique().tolist()  # 3 letter codes
+    return dominated, blacklist
+
+
 def remove_prefix(text: str, prefix: str):
     if text.startswith(prefix):
         return text[len(prefix) :]
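To make the benchmark-table parsing above concrete, here is a small self-contained sketch. The `lmap` helper mirrors the one already defined elsewhere in this converter (`list(map(...))`), and the markdown row is made up for illustration:

    from typing import List, Tuple

    lmap = lambda f, x: list(map(f, x))  # mirrors the converter's existing helper

    def _process_benchmark_table_row(x):
        # "| testset | BLEU | chr-F |" -> ("testset", bleu, chr-f)
        fields = lmap(str.strip, x.replace("\t", "").split("|")[1:-1])
        assert len(fields) == 3
        return (fields[0], float(fields[1]), float(fields[2]))

    sample = "## Benchmarks\n\n| testset | BLEU | chr-F |\n|---------|------|-------|\n| Tatoeba.en.de | 45.2 | 0.637 |\n"
    entries = sample.split("## Benchmarks")[-1].strip().split("\n")[2:]  # drop header and separator rows
    rows: List[Tuple[str, float, float]] = lmap(_process_benchmark_table_row, entries)
    print(rows)  # [('Tatoeba.en.de', 45.2, 0.637)]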
...
@@ -149,37 +232,87 @@ def convert_hf_name_to_opus_name(hf_model_name):
     return remove_prefix(opus_w_prefix, "opus-mt-")


+def get_system_metadata(repo_root):
+    import git
+
+    return dict(
+        helsinki_git_sha=git.Repo(path=repo_root, search_parent_directories=True).head.object.hexsha,
+        transformers_git_sha=git.Repo(path=".", search_parent_directories=True).head.object.hexsha,
+        port_machine=socket.gethostname(),
+        port_time=time.strftime("%Y-%m-%d-%H:%M"),
+    )
+
+
+front_matter = """---
+language: {}
+tags:
+- translation
+
+license: apache-2.0
+---
+"""
 def write_model_card(
     hf_model_name: str,
-    repo_path="OPUS-MT-train/models/",
+    repo_root="OPUS-MT-train",
+    save_dir=Path("marian_converted"),
     dry_run=False,
-    model_card_dir=Path("marian_converted/model_cards/Helsinki-NLP/"),
+    extra_metadata={},
 ) -> str:
     """Copy the most recent model's readme section from opus, and add metadata.
-    upload command: s3cmd sync --recursive model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/
+    upload command: aws s3 sync model_card_dir s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
     """
+    import pandas as pd
+
     hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
     opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
+    assert repo_root in ("OPUS-MT-train", "Tatoeba-Challenge")
+    opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
+    assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
+
     opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]
-    readme_url = OPUS_GITHUB_URL + f"{opus_name}/README.md"
+
+    readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"
+
     s, t = ",".join(opus_src), ",".join(opus_tgt)
-    extra_markdown = f"### {hf_model_name}\n\n* source languages: {s}\n* target languages: {t}\n* OPUS readme: [{opus_name}]({readme_url})\n"
+    metadata = {
+        "hf_name": hf_model_name,
+        "source_languages": s,
+        "target_languages": t,
+        "opus_readme_url": readme_url,
+        "original_repo": repo_root,
+        "tags": ["translation"],
+    }
+    metadata.update(extra_metadata)
+    metadata.update(get_system_metadata(repo_root))

     # combine with opus markdown
-    opus_readme_path = Path(f"{repo_path}{opus_name}/README.md")
-    assert opus_readme_path.exists(), f"Readme file {opus_readme_path} not found"
+    extra_markdown = f"### {hf_model_name}\n\n* source group: {metadata['src_name']}\n* target group: {metadata['tgt_name']}\n* OPUS readme: [{opus_name}]({readme_url})\n"
+
     content = opus_readme_path.open().read()
     content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.
-    content = "*".join(content.split("*")[1:])
-    content = extra_markdown + "\n* " + content.replace("download", "download original weights")
+    splat = content.split("*")[2:]
+    print(splat[3])
+    content = "*".join(splat)
+    content = (
+        front_matter.format(metadata["src_alpha2"])
+        + extra_markdown
+        + "\n* "
+        + content.replace("download", "download original weights")
+    )
+    items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
+    sec3 = "\n### System Info: \n" + items
+    content += sec3
     if dry_run:
-        return content
-    # Save string to model_cards/hf_model_name/readme.md
-    model_card_dir.mkdir(exist_ok=True)
-    sub_dir = model_card_dir / hf_model_name
+        return content, metadata
+    sub_dir = save_dir / f"opus-mt-{hf_model_name}"
     sub_dir.mkdir(exist_ok=True)
     dest = sub_dir / "README.md"
     dest.open("w").write(content)
-    return content
+    pd.Series(metadata).to_json(sub_dir / "metadata.json")
+    # if dry_run:
+    return content, metadata


 def get_clean_model_id_mapping(multiling_model_ids):
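Note that the reworked write_model_card now expects extra_metadata to supply fields such as src_name, tgt_name and src_alpha2 (presumably taken from the corresponding Tatoeba released-models.txt row). A hypothetical dry run, with invented field values and assuming Tatoeba-Challenge is cloned locally, could look like this:

    from transformers.convert_marian_to_pytorch import write_model_card

    # Values below are illustrative, not taken from the actual released-models.txt.
    content, metadata = write_model_card(
        "opus-mt-zho-eng",
        repo_root="Tatoeba-Challenge",  # requires Tatoeba-Challenge/ cloned next to the working dir
        dry_run=True,
        extra_metadata={"src_name": "Chinese", "tgt_name": "English", "src_alpha2": "zh"},
    )
    print(content[:200])  # YAML front matter followed by the "### opus-mt-zho-eng" section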
...
@@ -193,7 +326,7 @@ def make_registry(repo_path="Opus-MT-train/models"):
             "You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
         )
     results = {}
-    for p in Path(repo_path).ls():
+    for p in Path(repo_path).iterdir():
         n_dash = p.name.count("-")
         if n_dash == 0:
             continue
...
@@ -203,6 +336,21 @@ def make_registry(repo_path="Opus-MT-train/models"):
     return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]


+def make_tatoeba_registry(repo_path="Tatoeba-Challenge/models"):
+    if not (Path(repo_path) / "zho-eng" / "README.md").exists():
+        raise ValueError(
+            f"repo_path:{repo_path} does not exist: "
+            "You must run: git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git before calling."
+        )
+    results = {}
+    for p in Path(repo_path).iterdir():
+        if len(p.name) != 7:
+            continue
+        lns = list(open(p / "README.md").readlines())
+        results[p.name] = _parse_readme(lns)
+    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]
+
+
 def convert_all_sentencepiece_models(model_list=None, repo_path=None):
     """Requires 300GB"""
     save_dir = Path("marian_ckpt")
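Putting the new pieces together, one plausible bulk-conversion flow (my own wiring of these helpers, not something this commit prescribes) would build the Tatoeba registry, drop pairs whose older OPUS-MT port already scores higher BLEU, and convert the rest:

    # Assumes both OPUS-MT-train/ and Tatoeba-Challenge/ are cloned next to the working directory.
    from transformers.convert_marian_to_pytorch import (
        check_if_models_are_dominated,
        convert_all_sentencepiece_models,
        make_tatoeba_registry,
    )

    reg = make_tatoeba_registry()  # [(pair, preprocessing, download_url, test_set_url), ...]
    dominated, blacklist = check_if_models_are_dominated()  # pairs where the old port has higher BLEU
    keep = [entry for entry in reg if entry[0] not in blacklist]  # entry[0] is the pair dir name, e.g. "zho-eng"
    convert_all_sentencepiece_models(model_list=keep)  # saves converted checkpoints under marian_converted/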
...
@@ -516,19 +664,6 @@ def convert(source_dir: Path, dest_dir):
     model.from_pretrained(dest_dir)  # sanity check


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    # Required parameters
-    parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
-    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
-    args = parser.parse_args()
-
-    source_dir = Path(args.src)
-    assert source_dir.exists(), f"Source directory {source_dir} not found"
-    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
-    convert(source_dir, dest_dir)
-
-
 def load_yaml(path):
     import yaml
...
@@ -544,3 +679,23 @@ def save_json(content: Union[Dict, List], path: str) -> None:
 def unzip(zip_path: str, dest_dir: str) -> None:
     with ZipFile(zip_path, "r") as zipObj:
         zipObj.extractall(dest_dir)
+
+
+if __name__ == "__main__":
+    """
+    To bulk convert, run
+    >>> from transformers.convert_marian_to_pytorch import make_tatoeba_registry, convert_all_sentencepiece_models
+    >>> reg = make_tatoeba_registry()
+    >>> convert_all_sentencepiece_models(model_list=reg)  # saves to marian_converted
+    (bash) aws s3 sync marian_converted s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
+    """
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument("--src", type=str, help="path to marian model dir", default="en-de")
+    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
+    args = parser.parse_args()
+
+    source_dir = Path(args.src)
+    assert source_dir.exists(), f"Source directory {source_dir} not found"
+    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
+    convert(source_dir, dest_dir)
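For a single pair, the same entry point can also be driven from Python rather than the CLI in the __main__ block above; the directory name here is only an illustration of an unzipped Marian checkpoint:

    # Convert one downloaded Marian checkpoint directory (path is illustrative).
    from pathlib import Path
    from transformers.convert_marian_to_pytorch import convert

    source_dir = Path("en-de")  # assumed to contain the Marian npz weights, vocab files and decoder.yml
    convert(source_dir, "converted-en-de")  # writes a checkpoint loadable with MarianMTModel.from_pretrained()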
tests/test_modeling_marian.py  (view file @ 12d76241)
...
@@ -205,6 +205,17 @@ class TestMarian_MT_EN(MarianIntegrationTest):
         self._assert_generated_batch_equal_expected()


+class TestMarian_eng_zho(MarianIntegrationTest):
+    src = "eng"
+    tgt = "zho"
+    src_text = ["My name is Wolfgang and I live in Berlin"]
+    expected_text = ["我叫沃尔夫冈 我住在柏林"]
+
+    @slow
+    def test_batch_generation_eng_zho(self):
+        self._assert_generated_batch_equal_expected()
+
+
 class TestMarian_en_ROMANCE(MarianIntegrationTest):
     """Multilingual on target side."""
 ...