Commit 59d599a2 authored by Myle Ott
Browse files

Move helper functions from generate.py to fairseq/dictionary.py

parent af86c1ac
......@@ -38,13 +38,31 @@ class Dictionary(object):
return self.indices[sym]
return self.unk_index
def string(self, tensor, bpe_symbol=None, escape_unk=False):
    """Helper for converting a tensor of token indices to a string.

    Can optionally remove BPE symbols or escape <unk> words.

    Args:
        tensor: an iterable of token indices. A 2-D torch tensor is
            converted one row at a time.
        bpe_symbol: if not None, every occurrence of this substring is
            removed from the resulting sentence.
        escape_unk: passed to ``self.unk_string`` to choose how unknown
            tokens are rendered.

    Returns:
        The tokens joined by single spaces, with indices equal to
        ``self.eos()`` dropped. For 2-D input, one such line per row,
        joined by newlines.
    """
    if torch.is_tensor(tensor) and tensor.dim() == 2:
        # Recurse through this method (it is named `string`; the earlier
        # `to_string` call did not match it) and forward the options so
        # each row is formatted exactly like a 1-D input would be.
        return '\n'.join(self.string(row, bpe_symbol, escape_unk) for row in tensor)

    # Look the special indices up once instead of once per token.
    eos = self.eos()
    unk = self.unk()

    def token_string(i):
        if i == unk:
            return self.unk_string(escape_unk)
        return self[i]

    sent = ' '.join(token_string(i) for i in tensor if i != eos)
    if bpe_symbol is not None:
        sent = sent.replace(bpe_symbol, '')
    return sent
def unk_string(self, escape=False):
    """Return the unknown-word string.

    When ``escape`` is true the word is wrapped in one extra pair of
    angle brackets (``'<' + unk_word + '>'``); otherwise it is returned
    unchanged.
    """
    unk = self.unk_word
    return '<{}>'.format(unk) if escape else unk
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
......
......@@ -10,7 +10,7 @@ import sys
import torch
from torch.autograd import Variable
from fairseq import bleu, options, utils, tokenizer
from fairseq import bleu, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
......@@ -54,14 +54,18 @@ def main():
model.make_generation_fast_(not args.no_beamable_mm)
# Initialize generator
translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen)
translator = SequenceGenerator(
models, dataset.dst_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement
align_dict = {}
if args.unk_replace_dict != '':
assert args.interactive, "Unkown words replacing requires access to original source and is only" \
"supported in interactive mode"
assert args.interactive, \
'Unknown word replacement requires access to original source and is only supported in interactive mode'
with open(args.unk_replace_dict, 'r') as f:
for line in f:
l = line.split()
......@@ -80,27 +84,23 @@ def main():
hypo_tokens[i] = src_token
return ' '.join(hypo_tokens)
if use_cuda:
translator.cuda()
bpe_symbol = '@@ ' if args.remove_bpe else None
def display_hypotheses(id, src, orig, ref, hypos):
    """Print one example: source, optional original/reference, and hypotheses.

    Output lines are tab-separated and prefixed with ``S`` (source),
    ``O`` (original text), ``T`` (reference), ``H`` (hypothesis with its
    score) and ``A`` (alignment). Nothing is printed when ``args.quiet``
    is set. ``args``, ``dataset``, ``bpe_symbol`` and ``replace_unk`` are
    read from the enclosing scope.

    Args:
        id: example identifier, or None to omit the ``-<id>`` suffix.
            (The name shadows the builtin but is kept for callers.)
        src: source token indices, rendered with ``dataset.src_dict``.
        orig: original source text, or None to skip the ``O`` line.
        ref: reference token indices, or None to skip the ``T`` line.
        hypos: iterable of mappings with ``'tokens'``, ``'alignment'``
            and ``'score'`` entries.
    """
    if args.quiet:
        return
    id_str = '' if id is None else '-{}'.format(id)
    src_str = dataset.src_dict.string(src, bpe_symbol)
    print('S{}\t{}'.format(id_str, src_str))
    if orig is not None:
        print('O{}\t{}'.format(id_str, orig.strip()))
    if ref is not None:
        # The reference is printed with escaped unknown words.
        print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, bpe_symbol, escape_unk=True)))
    for hypo in hypos:
        hypo_str = dataset.dst_dict.string(hypo['tokens'], bpe_symbol)
        align_str = ' '.join(map(str, hypo['alignment']))
        if args.unk_replace_dict != '':
            hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
        print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
        print('A{}\t{}'.format(id_str, align_str))
if args.interactive:
......@@ -121,7 +121,7 @@ def main():
if not args.remove_bpe:
return tokens
assert (tokens == dataset.dst_dict.pad()).sum() == 0
hypo_minus_bpe = to_sentence(dataset.dst_dict, tokens, bpe_symbol)
hypo_minus_bpe = dataset.dst_dict.string(tokens, bpe_symbol)
return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
# Generate and compute BLEU score
......@@ -151,25 +151,5 @@ def main():
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
def to_token(dict, i, runk):
    """Return ``runk`` when ``i`` equals ``dict.unk()``, otherwise ``dict[i]``."""
    if i == dict.unk():
        return runk
    return dict[i]
def unk_symbol(dict, ref_unk=False):
    """Return ``dict.unk_word``, wrapped in angle brackets when ``ref_unk`` is true."""
    if ref_unk:
        return '<{}>'.format(dict.unk_word)
    return dict.unk_word
def to_sentence(dict, tokens, bpe_symbol=None, ref_unk=False):
    """Convert token indices to a space-separated sentence.

    Args:
        dict: dictionary object providing ``eos()``, ``unk()``,
            ``unk_word`` and index lookup.
        tokens: iterable of token indices. A 2-D torch tensor is
            converted row by row.
        bpe_symbol: if not None, every occurrence of this substring is
            removed from the sentence.
        ref_unk: passed to ``unk_symbol`` to choose how unknown tokens
            are rendered.

    Returns:
        The sentence string. For 2-D input, one sentence per row joined
        by newlines.
    """
    if torch.is_tensor(tokens) and tokens.dim() == 2:
        # Forward the formatting options; without them batched rows
        # would ignore bpe_symbol and ref_unk.
        return '\n'.join(to_sentence(dict, row, bpe_symbol, ref_unk) for row in tokens)
    eos = dict.eos()
    runk = unk_symbol(dict, ref_unk=ref_unk)
    sent = ' '.join(to_token(dict, i, runk) for i in tokens if i != eos)
    if bpe_symbol is not None:
        sent = sent.replace(bpe_symbol, '')
    return sent
# Call main() only when this file is executed as a script.
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment