"tests/python/pytorch/nn/test_nn.py" did not exist on "3fe5eea791b84280513bcb495aa7c4e1bd0fad9d"
Unverified commit 572a1d55, authored by Myle Ott, committed by GitHub

Fix `--output-format raw` option to preprocess.py (Fixes #188) (#190)

parent 70d61db4
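In this commit, `--output-format raw` makes preprocess.py copy the tokenized text files into `--destdir` instead of binarizing them, and training and generation then run with `--raw-text`. A minimal programmatic sketch of that code path, mirroring how the test helpers in this diff drive it (the language codes and file paths are illustrative; only `get_parser()`, `parse_args()`, and `main()` from the diff below are assumed):

# Sketch only: mirrors the flags exercised by the new test added in this commit.
# 'de'/'en' and the data paths are placeholders, not values from the repository.
import os

import preprocess  # fairseq's top-level preprocess.py

parser = preprocess.get_parser()
args = parser.parse_args([
    '--source-lang', 'de',
    '--target-lang', 'en',
    '--trainpref', os.path.join('data', 'train'),  # expects data/train.de and data/train.en
    '--destdir', 'data-raw',
    '--output-format', 'raw',  # copy tokenized text instead of binarizing
])
preprocess.main(args)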
preprocess.py:

@@ -126,29 +126,32 @@ def main(args):
             100 * res['nunk'] / res['ntok'], dict.unk_word))
         ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
 
-    def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
-        if output_format == 'binary':
+    def make_dataset(input_prefix, output_prefix, lang):
+        if args.output_format == 'binary':
             make_binary_dataset(input_prefix, output_prefix, lang)
-        elif output_format == 'raw':
+        elif args.output_format == 'raw':
             # Copy original text file to destination folder
-            output_text_file = dest_path(output_prefix, lang)
+            output_text_file = dest_path(
+                output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
+                lang,
+            )
             shutil.copyfile(file_name(input_prefix, lang), output_text_file)
 
-    def make_all(args, make_dataset, lang):
+    def make_all(lang):
         if args.trainpref:
-            make_dataset(args.trainpref, 'train', lang, args.output_format)
+            make_dataset(args.trainpref, 'train', lang)
         if args.validpref:
             for k, validpref in enumerate(args.validpref.split(',')):
                 outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
-                make_dataset(validpref, outprefix, lang, args.output_format)
+                make_dataset(validpref, outprefix, lang)
         if args.testpref:
             for k, testpref in enumerate(args.testpref.split(',')):
                 outprefix = 'test{}'.format(k) if k > 0 else 'test'
-                make_dataset(testpref, outprefix, lang, args.output_format)
+                make_dataset(testpref, outprefix, lang)
 
-    make_all(args, make_dataset, args.source_lang)
+    make_all(args.source_lang)
     if target:
-        make_all(args, make_dataset, args.target_lang)
+        make_all(args.target_lang)
 
     print('| Wrote preprocessed data to {}'.format(args.destdir))
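The substantive fix for raw output is the destination file name: the copied text now embeds the language pair, which is presumably what the `--raw-text` loader expects and what #188 reported as broken. A small sketch of the naming difference, assuming `dest_path(prefix, lang)` joins `args.destdir` with `<prefix>.<lang>` (that helper is not shown in this hunk, so its exact behavior is an assumption):

import os

# Assumption: dest_path(prefix, lang) == os.path.join(destdir, prefix + '.' + lang).
def old_raw_name(destdir, prefix, lang):
    # Before the fix: dest_path(output_prefix, lang)
    return os.path.join(destdir, '{}.{}'.format(prefix, lang))

def new_raw_name(destdir, prefix, src, tgt, lang):
    # After the fix: dest_path(output_prefix + '.<src>-<tgt>', lang)
    return os.path.join(destdir, '{}.{}-{}.{}'.format(prefix, src, tgt, lang))

print(old_raw_name('data-raw', 'train', 'de'))              # data-raw/train.de
print(new_raw_name('data-raw', 'train', 'de', 'en', 'de'))  # data-raw/train.de-en.de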
tests:

@@ -34,6 +34,14 @@ class TestTranslation(unittest.TestCase):
                 train_translation_model(data_dir, 'fconv_iwslt_de_en')
                 generate_main(data_dir)
 
+    def test_raw(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir, ['--output-format', 'raw'])
+                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
+                generate_main(data_dir, ['--raw-text'])
+
     def test_fp16(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_fp16') as data_dir:
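To run only the new case with the standard-library runner, something like the following works; the `tests.test_binaries` module path is an assumption (this view does not show the test file's path), so adjust the import to wherever `TestTranslation` lives:

# Hypothetical runner: the module path below is assumed, not shown in this diff.
import unittest

from tests.test_binaries import TestTranslation  # assumed location of TestTranslation

suite = unittest.TestSuite()
suite.addTest(TestTranslation('test_raw'))
unittest.TextTestRunner(verbosity=2).run(suite)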
@@ -144,9 +152,10 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
     _create_dummy_data('test.out')
 
 
-def preprocess_translation_data(data_dir):
+def preprocess_translation_data(data_dir, extra_flags=None):
     preprocess_parser = preprocess.get_parser()
-    preprocess_args = preprocess_parser.parse_args([
+    preprocess_args = preprocess_parser.parse_args(
+        [
         '--source-lang', 'in',
         '--target-lang', 'out',
         '--trainpref', os.path.join(data_dir, 'train'),
@@ -155,7 +164,8 @@ def preprocess_translation_data(data_dir):
         '--thresholdtgt', '0',
         '--thresholdsrc', '0',
         '--destdir', data_dir,
-    ])
+        ] + (extra_flags or []),
+    )
     preprocess.main(preprocess_args)
@@ -181,7 +191,7 @@ def train_translation_model(data_dir, arch, extra_flags=None):
     train.main(train_args)
 
 
-def generate_main(data_dir):
+def generate_main(data_dir, extra_flags=None):
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,
@@ -193,7 +203,7 @@ def generate_main(data_dir):
         '--max-len-b', '5',
         '--gen-subset', 'valid',
         '--no-progress-bar',
-        ],
+        ] + (extra_flags or []),
     )
     # evaluate model in batch mode
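All of the updated helpers take `extra_flags=None` and splice in `(extra_flags or [])` rather than defaulting to a literal `[]`. A short illustration of why that idiom is the safer choice (the function name here is illustrative, not from the repository):

def build_flags(base, extra_flags=None):
    # A literal [] default is created once at definition time and shared across
    # calls, so accidental mutation would leak between tests; None plus "or []"
    # produces a fresh list on every call.
    return list(base) + (extra_flags or [])

assert build_flags(['--no-progress-bar']) == ['--no-progress-bar']
assert build_flags(['--no-progress-bar'], ['--raw-text']) == ['--no-progress-bar', '--raw-text']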