OpenDAS / Fairseq · Commits
"vscode:/vscode.git/clone" did not exist on "2e209c30cf6f2ba42001d0629dc6b7ce354b9a9d"
Commit 572a1d55 (Unverified), authored Jun 21, 2018 by Myle Ott, committed by GitHub on Jun 21, 2018

Fix `--output-format raw` option to preprocess.py (Fixes #188) (#190)
Parent: 70d61db4
Showing 2 changed files with 36 additions and 23 deletions:

preprocess.py            +13 −10
tests/test_binaries.py   +23 −13
preprocess.py
@@ -126,29 +126,32 @@ def main(args):
             100 * res['nunk'] / res['ntok'], dict.unk_word))
         ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
 
-    def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
-        if output_format == 'binary':
+    def make_dataset(input_prefix, output_prefix, lang):
+        if args.output_format == 'binary':
             make_binary_dataset(input_prefix, output_prefix, lang)
-        elif output_format == 'raw':
+        elif args.output_format == 'raw':
             # Copy original text file to destination folder
-            output_text_file = dest_path(output_prefix, lang)
+            output_text_file = dest_path(
+                output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
+                lang,
+            )
             shutil.copyfile(file_name(input_prefix, lang), output_text_file)
 
-    def make_all(args, make_dataset, lang):
+    def make_all(lang):
         if args.trainpref:
-            make_dataset(args.trainpref, 'train', lang, args.output_format)
+            make_dataset(args.trainpref, 'train', lang)
         if args.validpref:
             for k, validpref in enumerate(args.validpref.split(',')):
                 outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
-                make_dataset(validpref, outprefix, lang, args.output_format)
+                make_dataset(validpref, outprefix, lang)
         if args.testpref:
             for k, testpref in enumerate(args.testpref.split(',')):
                 outprefix = 'test{}'.format(k) if k > 0 else 'test'
-                make_dataset(testpref, outprefix, lang, args.output_format)
+                make_dataset(testpref, outprefix, lang)
 
-    make_all(args, make_dataset, args.source_lang)
+    make_all(args.source_lang)
     if target:
-        make_all(args, make_dataset, args.target_lang)
+        make_all(args.target_lang)
 
     print('| Wrote preprocessed data to {}'.format(args.destdir))
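The substance of the fix is the destination name for the raw copies: each split is now written into args.destdir with the language-pair infix (for example train.in-out.in rather than train.in), which appears to match the <split>.<src>-<tgt>.<lang> naming that the --raw-text loading path looks for. Below is a minimal sketch of the resulting layout, assuming dest_path(prefix, lang) resolves to os.path.join(args.destdir, prefix + '.' + lang); copy_raw_split is a hypothetical helper for illustration, not code from this commit.

# Illustration only: reproduces the file naming of the fixed raw branch under
# the stated assumption about dest_path(); not part of preprocess.py.
import os
import shutil


def copy_raw_split(destdir, input_prefix, output_prefix, src, tgt, lang):
    # e.g. output_prefix='train', src='in', tgt='out', lang='in'
    #   -> <destdir>/train.in-out.in
    output_text_file = os.path.join(
        destdir, '{}.{}-{}.{}'.format(output_prefix, src, tgt, lang))
    # Copy the original plain-text split (e.g. <input_prefix>.in) unchanged.
    shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
    return output_text_file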
tests/test_binaries.py
@@ -34,6 +34,14 @@ class TestTranslation(unittest.TestCase):
                 train_translation_model(data_dir, 'fconv_iwslt_de_en')
                 generate_main(data_dir)
 
+    def test_raw(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir, ['--output-format', 'raw'])
+                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
+                generate_main(data_dir, ['--raw-text'])
+
     def test_fp16(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_fp16') as data_dir:
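The new test_raw mirrors the existing binary-format test: the only differences are ['--output-format', 'raw'] at preprocessing time and ['--raw-text'] for both training and generation. The following sketch folds that pairing into a single parameterized helper; run_dummy_pipeline is hypothetical and only illustrates how the flags line up, and the import assumes the tests module is importable from the repository root.

# Sketch only: parameterizes the shared shape of test_fconv and test_raw.
import contextlib
import tempfile
from io import StringIO

from tests.test_binaries import (  # assumes repo root on sys.path
    create_dummy_data,
    generate_main,
    preprocess_translation_data,
    train_translation_model,
)


def run_dummy_pipeline(arch='fconv_iwslt_de_en', raw=False):
    with contextlib.redirect_stdout(StringIO()):
        with tempfile.TemporaryDirectory('run_dummy_pipeline') as data_dir:
            create_dummy_data(data_dir)
            # Raw mode skips binarization and copies plain text into data_dir.
            preprocess_translation_data(
                data_dir, ['--output-format', 'raw'] if raw else None)
            # Training and generation must then also read raw text.
            train_translation_model(
                data_dir, arch, ['--raw-text'] if raw else None)
            generate_main(data_dir, ['--raw-text'] if raw else None)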
@@ -144,9 +152,10 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
     _create_dummy_data('test.out')
 
 
-def preprocess_translation_data(data_dir):
+def preprocess_translation_data(data_dir, extra_flags=None):
     preprocess_parser = preprocess.get_parser()
-    preprocess_args = preprocess_parser.parse_args([
+    preprocess_args = preprocess_parser.parse_args(
+        [
         '--source-lang', 'in',
         '--target-lang', 'out',
         '--trainpref', os.path.join(data_dir, 'train'),
@@ -155,7 +164,8 @@ def preprocess_translation_data(data_dir):
         '--thresholdtgt', '0',
         '--thresholdsrc', '0',
         '--destdir', data_dir,
-    ])
+        ] + (extra_flags or []),
+    )
     preprocess.main(preprocess_args)
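The helper change also shows the idiom used for optional flags throughout these tests: default extra_flags to None rather than to a mutable list, and append with (extra_flags or []) so that None and an empty list behave the same. A tiny standalone sketch of that idiom follows; build_args is illustrative, not a helper in the test file.

# Illustration of the optional-flags idiom used by the test helpers above.
def build_args(base_flags, extra_flags=None):
    # `extra_flags or []` treats both None and [] as "no extra flags",
    # and avoids the mutable-default-argument pitfall.
    return list(base_flags) + (extra_flags or [])


assert build_args(['--destdir', 'data']) == ['--destdir', 'data']
assert build_args(['--destdir', 'data'], ['--output-format', 'raw']) == [
    '--destdir', 'data', '--output-format', 'raw']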
@@ -181,7 +191,7 @@ def train_translation_model(data_dir, arch, extra_flags=None):
     train.main(train_args)
 
 
-def generate_main(data_dir):
+def generate_main(data_dir, extra_flags=None):
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,
@@ -193,7 +203,7 @@ def generate_main(data_dir):
         '--max-len-b', '5',
         '--gen-subset', 'valid',
         '--no-progress-bar',
-        ],
+        ] + (extra_flags or []),
     )
     # evaluate model in batch mode