Commit 59d599a2 authored by Myle Ott

Move helper functions from generate.py to fairseq/dictionary.py

parent af86c1ac
fairseq/dictionary.py

@@ -38,13 +38,31 @@ class Dictionary(object):
             return self.indices[sym]
         return self.unk_index
 
-    def string(self, tensor):
+    def string(self, tensor, bpe_symbol=None, escape_unk=False):
+        """Helper for converting a tensor of token indices to a string.
+
+        Can optionally remove BPE symbols or escape <unk> words.
+        """
         if torch.is_tensor(tensor) and tensor.dim() == 2:
-            sentences = [self.string(line) for line in tensor]
-            return '\n'.join(sentences)
-        eos = self.eos()
-        return ' '.join([self[i] for i in tensor if i != eos])
+            return '\n'.join(self.string(t) for t in tensor)
+
+        def token_string(i):
+            if i == self.unk():
+                return self.unk_string(escape_unk)
+            else:
+                return self[i]
+
+        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+        if bpe_symbol is not None:
+            sent = sent.replace(bpe_symbol, '')
+        return sent
+
+    def unk_string(self, escape=False):
+        """Return unknown string, optionally escaped as: <<unk>>"""
+        if escape:
+            return '<{}>'.format(self.unk_word)
+        else:
+            return self.unk_word
 
     def add_symbol(self, word, n=1):
         """Adds a word to the dictionary"""
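As an aside, here is a minimal sketch of the consolidated API in use (not part of the commit). The hand-built dictionary, token sequence, and the '@@ ' BPE marker below are hypothetical; in fairseq a Dictionary is normally built during preprocessing rather than by hand:

import torch
from fairseq.dictionary import Dictionary

# Toy dictionary; assumes the default special symbols <pad>, </s>, <unk>.
d = Dictionary()
for word in ['fair@@', 'seq', 'rocks']:
    d.add_symbol(word)

# A 1-D tensor of token indices ending in </s>, as a generator would emit.
tokens = torch.IntTensor([d.index('fair@@'), d.index('seq'), d.index('rocks'), d.eos()])

print(d.string(tokens))                    # fair@@ seq rocks
print(d.string(tokens, bpe_symbol='@@ '))  # fairseq rocks  (BPE subwords rejoined)
print(d.unk_string())                      # <unk>
print(d.unk_string(escape=True))           # <<unk>>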
generate.py

@@ -10,7 +10,7 @@ import sys
 import torch
 from torch.autograd import Variable
 
-from fairseq import bleu, options, utils, tokenizer
+from fairseq import bleu, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator
@@ -54,14 +54,18 @@ def main():
         model.make_generation_fast_(not args.no_beamable_mm)
 
     # Initialize generator
-    translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
-                                   stop_early=(not args.no_early_stop),
-                                   normalize_scores=(not args.unnormalized),
-                                   len_penalty=args.lenpen)
+    translator = SequenceGenerator(
+        models, dataset.dst_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
+        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
+    )
+    if use_cuda:
+        translator.cuda()
 
+    # Load alignment dictionary for unknown word replacement
     align_dict = {}
     if args.unk_replace_dict != '':
-        assert args.interactive, "Unkown words replacing requires access to original source and is only" \
-                                 "supported in interactive mode"
+        assert args.interactive, \
+            'Unknown word replacement requires access to original source and is only supported in interactive mode'
         with open(args.unk_replace_dict, 'r') as f:
             for line in f:
                 l = line.split()
@@ -80,27 +84,23 @@ def main():
                 hypo_tokens[i] = src_token
         return ' '.join(hypo_tokens)
 
-    if use_cuda:
-        translator.cuda()
-
     bpe_symbol = '@@ ' if args.remove_bpe else None
 
     def display_hypotheses(id, src, orig, ref, hypos):
         if args.quiet:
            return
         id_str = '' if id is None else '-{}'.format(id)
-        src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
+        src_str = dataset.src_dict.string(src, bpe_symbol)
         print('S{}\t{}'.format(id_str, src_str))
         if orig is not None:
             print('O{}\t{}'.format(id_str, orig.strip()))
         if ref is not None:
-            print('T{}\t{}'.format(id_str, to_sentence(dataset.dst_dict, ref, bpe_symbol, ref_unk=True)))
+            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, bpe_symbol, escape_unk=True)))
         for hypo in hypos:
-            hypo_str = to_sentence(dataset.dst_dict, hypo['tokens'], bpe_symbol)
+            hypo_str = dataset.dst_dict.string(hypo['tokens'], bpe_symbol)
             align_str = ' '.join(map(str, hypo['alignment']))
             if args.unk_replace_dict != '':
-                hypo_str = replace_unk(hypo_str, align_str, orig, unk_symbol(dataset.dst_dict))
-            print('H{}\t{}\t{}'.format(
-                id_str, hypo['score'], hypo_str))
+                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
+            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
             print('A{}\t{}'.format(id_str, align_str))
 
     if args.interactive:
@@ -121,7 +121,7 @@ def main():
         if not args.remove_bpe:
             return tokens
         assert (tokens == dataset.dst_dict.pad()).sum() == 0
-        hypo_minus_bpe = to_sentence(dataset.dst_dict, tokens, bpe_symbol)
+        hypo_minus_bpe = dataset.dst_dict.string(tokens, bpe_symbol)
         return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
 
     # Generate and compute BLEU score
@@ -151,25 +151,5 @@ def main():
     print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
 
-def to_token(dict, i, runk):
-    return runk if i == dict.unk() else dict[i]
-
-def unk_symbol(dict, ref_unk=False):
-    return '<{}>'.format(dict.unk_word) if ref_unk else dict.unk_word
-
-def to_sentence(dict, tokens, bpe_symbol=None, ref_unk=False):
-    if torch.is_tensor(tokens) and tokens.dim() == 2:
-        sentences = [to_sentence(dict, token) for token in tokens]
-        return '\n'.join(sentences)
-    eos = dict.eos()
-    runk = unk_symbol(dict, ref_unk=ref_unk)
-    sent = ' '.join([to_token(dict, i, runk) for i in tokens if i != eos])
-    if bpe_symbol is not None:
-        sent = sent.replace(bpe_symbol, '')
-    return sent
-
 if __name__ == '__main__':
     main()
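One detail worth noting: display_hypotheses prints reference translations (the T lines) with escape_unk=True, presumably so that an unknown word in the reference, rendered <<unk>>, cannot be mistaken for a genuine <unk> produced by the model. A small sketch of the distinction, again with a hypothetical hand-built dictionary:

import torch
from fairseq.dictionary import Dictionary

d = Dictionary()
d.add_symbol('hello')

# 'hello', then a word missing from the dictionary, then </s>.
tokens = torch.IntTensor([d.index('hello'), d.unk(), d.eos()])

print(d.string(tokens))                   # hello <unk>    (hypothesis-style)
print(d.string(tokens, escape_unk=True))  # hello <<unk>>  (reference-style)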