Commit 3dce7c9f authored by Myle Ott, committed by Facebook Github Bot
Browse files

Add --input option to interactive.py to support reading from file

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/484

Differential Revision: D13880636

Pulled By: myleott

fbshipit-source-id: 984b2e1c3b281c28243102eb971ea45ec891d94e
parent 42be3ebd
......@@ -379,6 +379,8 @@ def add_interactive_args(parser):
# fmt: off
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
group.add_argument('--input', default='-', type=str, metavar='FILE',
help='file to read from; use - for stdin')
# fmt: on
......
......@@ -10,9 +10,10 @@ Translate raw text with a trained model. Batches data on-the-fly.
"""
from collections import namedtuple
import numpy as np
import fileinput
import sys
import numpy as np
import torch
from fairseq import data, options, tasks, tokenizer, utils
......@@ -23,9 +24,9 @@ Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size):
def buffered_read(input, buffer_size):
buffer = []
for src_str in sys.stdin:
for src_str in fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")):
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
......@@ -165,12 +166,12 @@ def main(args):
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
for inputs in buffered_read(args.input, args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, task, max_positions):
indices.extend(batch_indices)
results += process_batch(batch)
results.extend(process_batch(batch))
for i in np.argsort(indices):
result = results[i]
......
......@@ -282,6 +282,7 @@ def generate_main(data_dir, extra_flags=None):
# evaluate model interactively
generate_args.buffer_size = 0
generate_args.input = '-'
generate_args.max_sentences = None
orig_stdin = sys.stdin
sys.stdin = StringIO('h e l l o\n')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment