Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b9ecd92e
Unverified
Commit
b9ecd92e
authored
Aug 10, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 10, 2020
Browse files
[s2s] Script to save wmt data to disk (#6403)
parent
00bb0b25
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
0 deletions
+48
-0
examples/seq2seq/download_wmt.py
examples/seq2seq/download_wmt.py
+48
-0
No files found.
examples/seq2seq/download_wmt.py
0 → 100644
View file @
b9ecd92e
from
pathlib
import
Path
import
fire
from
tqdm
import
tqdm
def download_wmt_dataset(src_lang, tgt_lang, dataset="wmt19", save_dir=None) -> None:
    """Download a dataset using the nlp package and save it to the format expected by finetune.py.

    Format of save_dir: train.source, train.target, val.source, val.target, test.source, test.target.
    Each file holds one example per line; line i of ``<split>.source`` and line i of
    ``<split>.target`` come from the same record.

    Args:
        src_lang: <str> source language
        tgt_lang: <str> target language
        dataset: <str> like wmt19 (if you don't know, try wmt19).
        save_dir: <str>, where to save the datasets, defaults to f'{dataset}-{src_lang}-{tgt_lang}'

    Raises:
        ImportError: if the ``nlp`` package is not installed.

    Usage:
        >>> download_wmt_dataset('en', 'ru', dataset='wmt19') # saves to wmt19-en-ru
    """
    try:
        import nlp
    except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
        raise ImportError("run pip install nlp") from err

    pair = f"{src_lang}-{tgt_lang}"
    ds = nlp.load_dataset(dataset, pair)

    if save_dir is None:
        save_dir = f"{dataset}-{pair}"
    save_dir = Path(save_dir)
    # parents=True so a nested save_dir such as "data/wmt19-en-ru" can be created in one go.
    save_dir.mkdir(parents=True, exist_ok=True)

    for split in tqdm(ds.keys()):
        # Single pass over the split: avoids holding several intermediate copies of the data.
        src, tgt = [], []
        for record in ds[split]:
            example = record["translation"]
            src.append(example[src_lang])
            tgt.append(example[tgt_lang])

        # to save to val.source, val.target like summary datasets
        file_stem = "val" if split == "validation" else split

        # write_text closes the file for us; explicit UTF-8 keeps non-ASCII text
        # from depending on the platform's default locale encoding.
        save_dir.joinpath(f"{file_stem}.source").write_text("\n".join(src), encoding="utf-8")
        save_dir.joinpath(f"{file_stem}.target").write_text("\n".join(tgt), encoding="utf-8")

    print(f"saved dataset to {save_dir}")
if __name__ == "__main__":
    # Command-line entry point: fire builds the CLI from download_wmt_dataset's signature.
    fire.Fire(download_wmt_dataset)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment