Commit 663fd806 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

implement batching in interactive mode

parent 4ce453b1
...@@ -25,10 +25,12 @@ def get_training_parser(): ...@@ -25,10 +25,12 @@ def get_training_parser():
return parser return parser
def get_generation_parser(): def get_generation_parser(interactive=False):
parser = get_parser('Generation') parser = get_parser('Generation')
add_dataset_args(parser, gen=True) add_dataset_args(parser, gen=True)
add_generation_args(parser) add_generation_args(parser)
if interactive:
add_interactive_args(parser)
return parser return parser
...@@ -242,6 +244,12 @@ def add_generation_args(parser): ...@@ -242,6 +244,12 @@ def add_generation_args(parser):
return group return group
def add_interactive_args(parser):
group = parser.add_argument_group('Interactive')
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
def add_model_args(parser): def add_model_args(parser):
group = parser.add_argument_group('Model configuration') group = parser.add_argument_group('Model configuration')
......
...@@ -6,20 +6,60 @@ ...@@ -6,20 +6,60 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import numpy as np
import sys import sys
import torch import torch
from collections import namedtuple
from torch.autograd import Variable from torch.autograd import Variable
from fairseq import options, tokenizer, utils from fairseq import options, tokenizer, utils
from fairseq.data import LanguagePairDataset
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
def buffered_read(buffer_size):
buffer = []
for src_str in sys.stdin:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
def make_batches(lines, batch_size, src_dict):
tokens = [tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() for src_str in lines]
lengths = [t.numel() for t in tokens]
indices = np.argsort(lengths)
num_batches = np.ceil(len(indices) / batch_size)
batches = np.array_split(indices, num_batches)
for batch_idxs in batches:
batch_toks = [tokens[i] for i in batch_idxs]
batch_toks = LanguagePairDataset.collate_tokens(batch_toks, src_dict.pad(), src_dict.eos(),
LanguagePairDataset.LEFT_PAD_SOURCE,
move_eos_to_beginning=False)
yield Batch(
srcs=[lines[i] for i in batch_idxs],
tokens=batch_toks,
lengths=tokens[0].new([lengths[i] for i in batch_idxs]),
), batch_idxs
def main(args): def main(args):
print(args) print(args)
assert not args.sampling or args.nbest == args.beam, \ assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam' '--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences, \ assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size is not supported in interactive mode' '--max-sentences/--batch-size cannot be larger than --buffer-size'
if args.buffer_size < 1:
args.buffer_size = 1
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
...@@ -49,19 +89,12 @@ def main(args): ...@@ -49,19 +89,12 @@ def main(args):
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(args.replace_unk)
print('| Type the input sentence and press return:') def make_result(src_str, hypos):
for src_str in sys.stdin: result = Translation(
src_str = src_str.strip() src_str='O\t{}'.format(src_str),
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() hypos=[],
if use_cuda: alignments=[],
src_tokens = src_tokens.cuda()
src_lengths = src_tokens.new([src_tokens.numel()])
translations = translator.generate(
Variable(src_tokens.view(1, -1)),
Variable(src_lengths.view(-1)),
) )
hypos = translations[0]
print('O\t{}'.format(src_str))
# Process top predictions # Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]: for hypo in hypos[:min(len(hypos), args.nbest)]:
...@@ -73,11 +106,45 @@ def main(args): ...@@ -73,11 +106,45 @@ def main(args):
dst_dict=dst_dict, dst_dict=dst_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
print('H\t{}\t{}'.format(hypo['score'], hypo_str)) result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))) result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
translations = translator.generate(
Variable(tokens),
Variable(lengths),
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, max(1, args.max_sentences or 1), src_dict):
indices.extend(batch_indices)
results += process_batch(batch)
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments):
print(hypo)
print(align)
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser() parser = options.get_generation_parser(interactive=True)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment