"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b56f17ae1ae8a5d08067c7f7444af21fb3b59ca6"
Commit 3dce7c9f authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Add --input option to interactive.py to support reading from file

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/484

Differential Revision: D13880636

Pulled By: myleott

fbshipit-source-id: 984b2e1c3b281c28243102eb971ea45ec891d94e
parent 42be3ebd
@@ -379,6 +379,8 @@ def add_interactive_args(parser):
    # fmt: off
    group.add_argument('--buffer-size', default=0, type=int, metavar='N',
                       help='read this many sentences into a buffer before processing them')
    group.add_argument('--input', default='-', type=str, metavar='FILE',
                       help='file to read from; use - for stdin')
    # fmt: on
......
@@ -10,9 +10,10 @@ Translate raw text with a trained model. Batches data on-the-fly.
"""
from collections import namedtuple
import fileinput
import sys

import numpy as np
import torch

from fairseq import data, options, tasks, tokenizer, utils
@@ -23,9 +24,9 @@ Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def buffered_read(input, buffer_size):
    buffer = []
    for src_str in fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")):
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
@@ -165,12 +166,12 @@ def main(args):
    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.input, args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
            indices.extend(batch_indices)
            results.extend(process_batch(batch))
        for i in np.argsort(indices):
            result = results[i]
......
@@ -282,6 +282,7 @@ def generate_main(data_dir, extra_flags=None):
    # evaluate model interactively
    generate_args.buffer_size = 0
    generate_args.input = '-'
    generate_args.max_sentences = None
    orig_stdin = sys.stdin
    sys.stdin = StringIO('h e l l o\n')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment