Unverified commit 572a1d55, authored by Myle Ott, committed by GitHub
Browse files

Fix `--output-format raw` option to preprocess.py (Fixes #188) (#190)

parent 70d61db4
......@@ -126,29 +126,32 @@ def main(args):
100 * res['nunk'] / res['ntok'], dict.unk_word))
ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
def make_dataset(input_prefix, output_prefix, lang):
    """Preprocess one data split (train/valid/test) for one language.

    Dispatches on ``args.output_format``:
      * ``'binary'`` -- build an indexed binary dataset via
        ``make_binary_dataset``.
      * ``'raw'`` -- copy the original text file into the destination
        folder unchanged, embedding the language pair in the filename
        (e.g. ``train.src-tgt.src``) so raw-text loaders can find it.
    """
    if args.output_format == 'binary':
        make_binary_dataset(input_prefix, output_prefix, lang)
    elif args.output_format == 'raw':
        # Copy original text file to destination folder; the
        # '{source}-{target}' infix matches what --raw-text loading expects.
        output_text_file = dest_path(
            output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
            lang,
        )
        shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang):
    """Build datasets for every configured split in language ``lang``.

    Reads the split prefixes from the enclosing ``args``; valid/test
    prefixes may be comma-separated lists, which produce outputs named
    ``valid``, ``valid1``, ``valid2``, ... (same for ``test``).
    """
    if args.trainpref:
        make_dataset(args.trainpref, 'train', lang)
    if args.validpref:
        for k, validpref in enumerate(args.validpref.split(',')):
            # First entry keeps the plain name; later ones get an index.
            outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
            make_dataset(validpref, outprefix, lang)
    if args.testpref:
        for k, testpref in enumerate(args.testpref.split(',')):
            outprefix = 'test{}'.format(k) if k > 0 else 'test'
            make_dataset(testpref, outprefix, lang)
# Preprocess the source side, and the target side as well when this is a
# bilingual (source -> target) setup.
make_all(args.source_lang)
if target:
    make_all(args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
......
......@@ -34,6 +34,14 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en')
generate_main(data_dir)
def test_raw(self):
    """End-to-end check of raw-text preprocessing, training, generation."""
    # Silence the pipeline's stdout and work in a throwaway directory.
    with contextlib.redirect_stdout(StringIO()), \
            tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
        create_dummy_data(data_dir)
        preprocess_translation_data(data_dir, ['--output-format', 'raw'])
        train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
        generate_main(data_dir, ['--raw-text'])
def test_fp16(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fp16') as data_dir:
......@@ -144,9 +152,10 @@ def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
_create_dummy_data('test.out')
def preprocess_translation_data(data_dir):
def preprocess_translation_data(data_dir, extra_flags=None):
preprocess_parser = preprocess.get_parser()
preprocess_args = preprocess_parser.parse_args([
preprocess_args = preprocess_parser.parse_args(
[
'--source-lang', 'in',
'--target-lang', 'out',
'--trainpref', os.path.join(data_dir, 'train'),
......@@ -155,7 +164,8 @@ def preprocess_translation_data(data_dir):
'--thresholdtgt', '0',
'--thresholdsrc', '0',
'--destdir', data_dir,
])
] + (extra_flags or []),
)
preprocess.main(preprocess_args)
......@@ -181,7 +191,7 @@ def train_translation_model(data_dir, arch, extra_flags=None):
train.main(train_args)
def generate_main(data_dir):
def generate_main(data_dir, extra_flags=None):
generate_parser = options.get_generation_parser()
generate_args = options.parse_args_and_arch(
generate_parser,
......@@ -193,7 +203,7 @@ def generate_main(data_dir):
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
],
] + (extra_flags or []),
)
# evaluate model in batch mode
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment