Commit b85fb035 authored by Jiatao Gu, committed by Facebook Github Bot

Enable printing the history of NAT; fix LevT decoding bug (#908)

Summary:
(1) Enable printing the iterative refinement history for all NAT models by setting --retain-iter-history during decoding;
(2) Fix a small bug in the decoding process of the Levenshtein Transformer.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/908

Differential Revision: D18493234

Pulled By: MultiPath

fbshipit-source-id: 9e7702adcea49f39d3c10b5349b5a9ae66399a24
parent e26ee47a
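
For context, decoding with the new flag looks roughly like this (a hypothetical invocation: the data and checkpoint paths are placeholders; the remaining flags already exist in fairseq's generation options):

    python generate.py data-bin/wmt14.en-de \
        --task translation_lev \
        --path checkpoints/checkpoint_best.pt \
        --gen-subset test \
        --iter-decode-max-iter 9 \
        --print-step \
        --retain-iter-history

Each finalized hypothesis then carries a 'history' list of intermediate decodings, which generate.py prints as E-lines (see the last hunk below).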
@@ -123,3 +123,7 @@ data-bin/
 # Cython-generated C++ source files
 /fairseq/data/data_utils_fast.cpp
 /fairseq/data/token_block_utils_fast.cpp
+
+# VSCODE
+.vscode/ftp-sync.json
+.vscode/settings.json
@@ -16,6 +16,7 @@ DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
     'attn',
     'step',
     'max_step',
+    'history'
 ])
@@ -30,6 +31,7 @@ class IterativeRefinementGenerator(object):
         decoding_format=None,
         retain_dropout=False,
         adaptive=True,
+        retain_history=False,
     ):
         """
         Generates translations based on iterative refinement.
@@ -53,6 +55,7 @@ class IterativeRefinementGenerator(object):
         self.max_ratio = max_ratio
         self.decoding_format = decoding_format
         self.retain_dropout = retain_dropout
+        self.retain_history = retain_history
         self.adaptive = adaptive
         self.models = models
@@ -123,6 +126,9 @@ class IterativeRefinementGenerator(object):
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
         prev_output_tokens = prev_decoder_out.output_tokens.clone()

+        if self.retain_history:
+            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])
+
         finalized = [[] for _ in range(bsz)]

         def is_a_loop(x, y, s, a):
@@ -139,7 +145,12 @@ class IterativeRefinementGenerator(object):
         def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
             cutoff = prev_out_token.ne(self.pad)
             tokens = prev_out_token[cutoff]
-            scores = prev_out_score[cutoff]
+            if prev_out_score is None:
+                scores, score = None, None
+            else:
+                scores = prev_out_score[cutoff]
+                score = scores.mean()
+
             if prev_out_attn is None:
                 hypo_attn, alignment = None, None
             else:
@@ -149,7 +160,7 @@ class IterativeRefinementGenerator(object):
                 "steps": step,
                 "tokens": tokens,
                 "positional_scores": scores,
-                "score": scores.mean(),
+                "score": score,
                 "hypo_attn": hypo_attn,
                 "alignment": alignment,
             }
@@ -195,6 +206,9 @@ class IterativeRefinementGenerator(object):
                 None if decoder_out.attn is None else decoder_out.attn[terminated]
             )

+            if self.retain_history:
+                finalized_history_tokens = [h[terminated] for h in decoder_out.history]
+
             for i in range(finalized_idxs.size(0)):
                 finalized[finalized_idxs[i]] = [
                     finalized_hypos(
@@ -204,6 +218,18 @@ class IterativeRefinementGenerator(object):
                         None if finalized_attn is None else finalized_attn[i],
                     )
                 ]
+
+                if self.retain_history:
+                    finalized[finalized_idxs[i]][0]['history'] = []
+                    for j in range(len(finalized_history_tokens)):
+                        finalized[finalized_idxs[i]][0]['history'].append(
+                            finalized_hypos(
+                                step,
+                                finalized_history_tokens[j][i],
+                                None, None
+                            )
+                        )
+
             # check if all terminated
             if terminated.sum() == terminated.size(0):
                 break
@@ -214,6 +240,7 @@ class IterativeRefinementGenerator(object):
                 output_tokens=decoder_out.output_tokens[not_terminated],
                 output_scores=decoder_out.output_scores[not_terminated],
                 attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
+                history=[h[not_terminated] for h in decoder_out.history] if decoder_out.history is not None else None
             )
             encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
             sent_idxs = sent_idxs[not_terminated]
...
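
The plumbing above works because namedtuple._replace returns a new tuple that shares the same underlying history list, so the models can append snapshots in place while the generator keeps swapping tuples between iterations. A minimal sketch of the pattern, with plain ints standing in for tensors (not fairseq code):

    from collections import namedtuple

    DecoderOut = namedtuple('DecoderOut', ['output_tokens', 'history'])

    out = DecoderOut(output_tokens=0, history=None)
    # retain_history: seed the history with the initial output tokens
    out = out._replace(history=[out.output_tokens])

    for step in range(3):
        new_tokens = step + 1                  # stand-in for one refinement step
        if out.history is not None:
            out.history.append(new_tokens)     # models append a snapshot per edit
        out = out._replace(output_tokens=new_tokens)

    print(out.history)                         # [0, 1, 2, 3]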
@@ -60,6 +60,7 @@ class CMLMNATransformerModel(NATransformerModel):
         output_tokens = decoder_out.output_tokens
         output_scores = decoder_out.output_scores
+        history = decoder_out.history

         # execute the decoder
         output_masks = output_tokens.eq(self.unk)
@@ -69,6 +70,9 @@ class CMLMNATransformerModel(NATransformerModel):
         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
         output_scores.masked_scatter_(output_masks, _scores[output_masks])

+        if history is not None:
+            history.append(output_tokens.clone())
+
         # skeptical decoding (depend on the maximum decoding steps.)
         if (step + 1) < max_step:
             skeptical_mask = _skeptical_unmasking(
@@ -78,10 +82,14 @@ class CMLMNATransformerModel(NATransformerModel):
             output_tokens.masked_fill_(skeptical_mask, self.unk)
             output_scores.masked_fill_(skeptical_mask, 0.0)

+            if history is not None:
+                history.append(output_tokens.clone())
+
         return decoder_out._replace(
             output_tokens=output_tokens,
             output_scores=output_scores,
             attn=None,
+            history=history
         )
...
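
Note that every snapshot is appended as output_tokens.clone(), not output_tokens: the decoding steps mutate the tensor in place (masked_scatter_, masked_fill_), so an un-cloned entry would alias the live tensor and all earlier history entries would change retroactively. A standalone PyTorch illustration (not fairseq code):

    import torch

    t = torch.tensor([1, 2, 3])
    history = [t]                  # no clone: entry aliases t
    t.masked_fill_(t.eq(2), 0)     # in-place edit, like the decoders above
    print(history[0])              # tensor([1, 0, 3]) -- snapshot corrupted

    t = torch.tensor([1, 2, 3])
    history = [t.clone()]          # clone decouples the snapshot
    t.masked_fill_(t.eq(2), 0)
    print(history[0])              # tensor([1, 2, 3]) -- snapshot intact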
@@ -172,6 +172,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
         output_tokens = decoder_out.output_tokens
         output_scores = decoder_out.output_scores
+        history = decoder_out.history
+
         # TODO: decoding for InsertionTransformer
         word_ins_out = self.decoder.forward_word_ins(
             output_tokens, encoder_out=encoder_out
@@ -188,10 +190,15 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
         cut_off = output_tokens.ne(self.pad).sum(1).max()
         output_tokens = output_tokens[:, :cut_off]
         output_scores = output_scores[:, :cut_off]
+
+        if history is not None:
+            history.append(output_tokens.clone())
+
         return decoder_out._replace(
             output_tokens=output_tokens,
             output_scores=output_scores,
             attn=None,
+            history=history
         )
...
@@ -410,6 +410,7 @@ class LevenshteinTransformerModel(TransformerModel):
         output_tokens = decoder_out.output_tokens
         output_scores = decoder_out.output_scores
         attn = decoder_out.attn
+        history = decoder_out.history

         bsz = output_tokens.size(0)
         if max_ratio is None:
@@ -446,6 +447,9 @@ class LevenshteinTransformerModel(TransformerModel):
             output_scores = _fill(output_scores, can_del_word, _scores, 0)
             attn = _fill(attn, can_del_word, _attn, 0.)

+            if history is not None:
+                history.append(output_tokens.clone())
+
         # insert placeholders
         can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
         if can_ins_mask.sum() != 0:
@@ -472,6 +476,9 @@ class LevenshteinTransformerModel(TransformerModel):
             output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
             output_scores = _fill(output_scores, can_ins_mask, _scores, 0)

+            if history is not None:
+                history.append(output_tokens.clone())
+
         # insert words
         can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
         if can_ins_word.sum() != 0:
@@ -480,8 +487,6 @@ class LevenshteinTransformerModel(TransformerModel):
                 _skip_encoder_out(self.encoder, encoder_out, can_ins_word)
             )
             word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
-            word_ins_pred = word_ins_score.max(-1)[1]
-
             _tokens, _scores = _apply_ins_words(
                 output_tokens[can_ins_word],
                 output_scores[can_ins_word],
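
This removal is the LevT decoding fix mentioned in the summary. F.log_softmax(word_ins_out, 2).max(-1) already yields vocabulary ids in word_ins_pred; the deleted line then overwrote them with the argmax of the already-reduced (batch, length) scores, i.e. with positions along the sequence instead of token ids. A standalone shape check (not fairseq code):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 100)   # (batch, length, vocab)
    word_ins_score, word_ins_pred = F.log_softmax(logits, 2).max(-1)
    print(word_ins_pred.shape)        # torch.Size([2, 5]) -- a token id per position
    bad = word_ins_score.max(-1)[1]   # what the deleted line computed
    print(bad.shape)                  # torch.Size([2])    -- sequence positions, not ids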
@@ -494,6 +499,9 @@ class LevenshteinTransformerModel(TransformerModel):
             output_scores = _fill(output_scores, can_ins_word, _scores, 0)
             attn = _fill(attn, can_ins_word, word_ins_attn, 0.)

+            if history is not None:
+                history.append(output_tokens.clone())
+
         # delete some unnecessary paddings
         cut_off = output_tokens.ne(self.pad).sum(1).max()
         output_tokens = output_tokens[:, :cut_off]
@@ -504,6 +512,7 @@ class LevenshteinTransformerModel(TransformerModel):
             output_tokens=output_tokens,
             output_scores=output_scores,
             attn=attn,
+            history=history
         )

     def initialize_output_tokens(self, encoder_out, src_tokens):
@@ -520,6 +529,7 @@ class LevenshteinTransformerModel(TransformerModel):
             attn=None,
             step=0,
             max_step=0,
+            history=None
         )
...
@@ -117,6 +117,7 @@ class NATransformerModel(TransformerModel):
         step = decoder_out.step
         output_tokens = decoder_out.output_tokens
         output_scores = decoder_out.output_scores
+        history = decoder_out.history

         # execute the decoder
         output_masks = output_tokens.ne(self.pad)
@@ -128,11 +129,14 @@ class NATransformerModel(TransformerModel):
         )
         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
         output_scores.masked_scatter_(output_masks, _scores[output_masks])

+        if history is not None:
+            history.append(output_tokens.clone())
         return decoder_out._replace(
             output_tokens=output_tokens,
             output_scores=output_scores,
             attn=None,
+            history=history
         )

     def initialize_output_tokens(self, encoder_out, src_tokens):
@@ -160,6 +164,7 @@ class NATransformerModel(TransformerModel):
             attn=None,
             step=0,
             max_step=0,
+            history=None
         )
...
@@ -506,6 +506,8 @@ def add_generation_args(parser):
                        help='maximum iterations for iterative refinement.')
     group.add_argument('--iter-decode-force-max-iter', action='store_true',
                        help='if set, run exact the maximum number of iterations without early stop')
+    group.add_argument('--retain-iter-history', action='store_true',
+                       help='if set, decoding returns the whole history of iterative refinement')

     # special decoding format for advanced decoding.
     group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
...
@@ -131,7 +131,8 @@ class TranslationLevenshteinTask(TranslationTask):
             eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
             max_iter=getattr(args, 'iter_decode_max_iter', 10),
             decoding_format=getattr(args, 'decoding_format', None),
-            adaptive=not getattr(args, 'iter_decode_force_max_iter', False))
+            adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
+            retain_history=getattr(args, 'retain_iter_history', False))

     def train_step(self,
                    sample,
...
@@ -162,6 +162,15 @@ def main(args):
                     if args.print_step:
                         print('I-{}\t{}'.format(sample_id, hypo['steps']))

+                    if getattr(args, 'retain_iter_history', False):
+                        print("\n".join([
+                            'E-{}_{}\t{}'.format(
+                                sample_id, step,
+                                utils.post_process_prediction(
+                                    h['tokens'].int().cpu(),
+                                    src_str, None, None, tgt_dict, None)[1])
+                            for step, h in enumerate(hypo['history'])]))
+
                 # Score only the top hypothesis
                 if has_target and j == 0:
                     if align_dict is not None or args.remove_bpe is not None:
...
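
With the flag set, generate.py prints one E-line per entry in hypo['history'] next to the usual H/I lines; the first entry is the initialized output and later entries are successive refinements. A hypothetical excerpt (target-side tokens invented for illustration):

    E-0_0	<unk> <unk> <unk> <unk>
    E-0_1	ein mann geht die straße entlang .
    E-0_2	ein mann geht langsam die straße entlang .

The index after the underscore is the refinement step, so the final E-line matches the sentence printed on the H-line.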