Commit b85fb035 authored by Jiatao Gu, committed by Facebook Github Bot

Enable printing the history of NAT; fix LevT decoding bug (#908)

Summary:
(1) Enable printing the iterative refinement history for all NAT models by setting --retain-iter-history during decoding;
(2) Fix a small bug in the decoding process of the Levenshtein Transformer.
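Example decoding command (dataset path and checkpoint name hypothetical; all flags other than --retain-iter-history already exist):

```
python generate.py data-bin/wmt14.en-de --task translation_lev \
    --path checkpoint_best.pt --iter-decode-max-iter 9 \
    --retain-iter-history --print-step
```

Each retained step is then printed as an extra 'E-<sample_id>_<step>' line next to the usual hypothesis output.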
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/908

Differential Revision: D18493234

Pulled By: MultiPath

fbshipit-source-id: 9e7702adcea49f39d3c10b5349b5a9ae66399a24
parent e26ee47a
.gitignore
@@ -123,3 +123,7 @@ data-bin/
# Cython-generated C++ source files
/fairseq/data/data_utils_fast.cpp
/fairseq/data/token_block_utils_fast.cpp
# VSCODE
.vscode/ftp-sync.json
.vscode/settings.json
iterative_refinement_generator.py
@@ -16,6 +16,7 @@ DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
'attn',
'step',
'max_step',
'history'
])
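The new history field carries one token snapshot per refinement step. A minimal standalone sketch of how the extended namedtuple is threaded (dummy values; assumes the six-field layout visible in this diff):

```python
import torch
from collections import namedtuple

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step', 'history'
])

tokens = torch.tensor([[0, 7, 8, 2]])  # dummy ids: <s> w1 w2 </s>
out = DecoderOut(output_tokens=tokens,
                 output_scores=torch.zeros(tokens.shape),
                 attn=None, step=0, max_step=10, history=None)

# With --retain-iter-history the generator seeds the history with the
# initial tokens, and each model appends a snapshot after every edit:
out = out._replace(history=[out.output_tokens.clone()])
out.history.append(out.output_tokens.clone())
assert len(out.history) == 2
```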
@@ -30,6 +31,7 @@ class IterativeRefinementGenerator(object):
decoding_format=None,
retain_dropout=False,
adaptive=True,
retain_history=False,
):
"""
Generates translations based on iterative refinement.
@@ -53,6 +55,7 @@ class IterativeRefinementGenerator(object):
self.max_ratio = max_ratio
self.decoding_format = decoding_format
self.retain_dropout = retain_dropout
self.retain_history = retain_history
self.adaptive = adaptive
self.models = models
@@ -123,6 +126,9 @@ class IterativeRefinementGenerator(object):
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out.output_tokens.clone()
if self.retain_history:
prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])
finalized = [[] for _ in range(bsz)]
def is_a_loop(x, y, s, a):
@@ -139,7 +145,12 @@ class IterativeRefinementGenerator(object):
def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
cutoff = prev_out_token.ne(self.pad)
tokens = prev_out_token[cutoff]
- scores = prev_out_score[cutoff]
+ if prev_out_score is None:
+     scores, score = None, None
+ else:
+     scores = prev_out_score[cutoff]
+     score = scores.mean()
if prev_out_attn is None:
hypo_attn, alignment = None, None
else:
@@ -149,7 +160,7 @@ class IterativeRefinementGenerator(object):
"steps": step,
"tokens": tokens,
"positional_scores": scores,
"score": scores.mean(),
"score": score,
"hypo_attn": hypo_attn,
"alignment": alignment,
}
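History snapshots carry tokens only, so finalized_hypos must now tolerate prev_out_score=None. A quick standalone check of the None-guard logic (mock inputs; pad index 1 assumed, as in fairseq's default dictionary):

```python
import torch

pad = 1
prev_out_token = torch.tensor([5, 6, 7, pad, pad])
prev_out_score = None                  # history entries have no scores

cutoff = prev_out_token.ne(pad)        # drop padding positions
tokens = prev_out_token[cutoff]
if prev_out_score is None:
    scores, score = None, None
else:
    scores = prev_out_score[cutoff]
    score = scores.mean()
assert score is None and tokens.tolist() == [5, 6, 7]
```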
@@ -195,6 +206,9 @@ class IterativeRefinementGenerator(object):
None if decoder_out.attn is None else decoder_out.attn[terminated]
)
if self.retain_history:
finalized_history_tokens = [h[terminated] for h in decoder_out.history]
for i in range(finalized_idxs.size(0)):
finalized[finalized_idxs[i]] = [
finalized_hypos(
@@ -204,6 +218,18 @@ class IterativeRefinementGenerator(object):
None if finalized_attn is None else finalized_attn[i],
)
]
if self.retain_history:
finalized[finalized_idxs[i]][0]['history'] = []
for j in range(len(finalized_history_tokens)):
finalized[finalized_idxs[i]][0]['history'].append(
finalized_hypos(
step,
finalized_history_tokens[j][i],
None, None
)
)
# check if all terminated
if terminated.sum() == terminated.size(0):
break
@@ -214,6 +240,7 @@ class IterativeRefinementGenerator(object):
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
history=[h[not_terminated] for h in decoder_out.history] if decoder_out.history is not None else None
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]
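Like the other decoder-state tensors, every history snapshot is batch-major, so shrinking the batch to the unfinished sentences is a row mask applied to each step. Standalone illustration:

```python
import torch

not_terminated = torch.tensor([True, False, True])
history = [torch.tensor([[4, 5], [6, 7], [8, 9]])]   # one step, batch of 3
history = [h[not_terminated] for h in history]       # drop the finished sentence
assert history[0].tolist() == [[4, 5], [8, 9]]
```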
cmlm_transformer.py
@@ -60,6 +60,7 @@ class CMLMNATransformerModel(NATransformerModel):
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
history = decoder_out.history
# execute the decoder
output_masks = output_tokens.eq(self.unk)
@@ -69,6 +70,9 @@ class CMLMNATransformerModel(NATransformerModel):
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
if history is not None:
history.append(output_tokens.clone())
# skeptical decoding (depends on the maximum decoding steps)
if (step + 1) < max_step:
skeptical_mask = _skeptical_unmasking(
@@ -78,10 +82,14 @@ class CMLMNATransformerModel(NATransformerModel):
output_tokens.masked_fill_(skeptical_mask, self.unk)
output_scores.masked_fill_(skeptical_mask, 0.0)
if history is not None:
history.append(output_tokens.clone())
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
history=history
)
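For context, skeptical_mask comes from CMLM's mask-predict step: between iterations, the lowest-scoring fraction of positions is re-masked and re-predicted. A rough standalone sketch of that idea (not the exact _skeptical_unmasking implementation):

```python
import torch

def skeptical_unmasking_sketch(output_scores, nonpad_mask, p):
    # Re-mask roughly the `p` lowest-scoring fraction of non-pad tokens.
    scores = output_scores.masked_fill(~nonpad_mask, float('inf'))
    num_remask = (nonpad_mask.sum(1, keepdim=True).float() * p).long()
    sorted_idx = scores.sort(-1)[1]              # positions, worst to best
    ranks = torch.arange(scores.size(1)).expand_as(scores)
    remask_by_rank = ranks < num_remask          # True for the worst ranks
    # Map rank-space booleans back to token positions.
    return remask_by_rank.scatter(1, sorted_idx, remask_by_rank)

scores = torch.tensor([[0.9, 0.1, 0.5, 0.8]])
mask = torch.ones(1, 4, dtype=torch.bool)
print(skeptical_unmasking_sketch(scores, mask, 0.5))
# tensor([[False,  True,  True, False]]) -- the two lowest scores get re-masked
```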
insertion_transformer.py
@@ -172,6 +172,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
history = decoder_out.history
# TODO: decoding for InsertionTransformer
word_ins_out = self.decoder.forward_word_ins(
output_tokens, encoder_out=encoder_out
@@ -188,10 +190,15 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
if history is not None:
history.append(output_tokens.clone())
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
history=history
)
levenshtein_transformer.py
@@ -410,6 +410,7 @@ class LevenshteinTransformerModel(TransformerModel):
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
attn = decoder_out.attn
history = decoder_out.history
bsz = output_tokens.size(0)
if max_ratio is None:
@@ -446,6 +447,9 @@ class LevenshteinTransformerModel(TransformerModel):
output_scores = _fill(output_scores, can_del_word, _scores, 0)
attn = _fill(attn, can_del_word, _attn, 0.)
if history is not None:
history.append(output_tokens.clone())
# insert placeholders
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
if can_ins_mask.sum() != 0:
@@ -472,6 +476,9 @@ class LevenshteinTransformerModel(TransformerModel):
output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
if history is not None:
history.append(output_tokens.clone())
# insert words
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
if can_ins_word.sum() != 0:
@@ -480,8 +487,6 @@ class LevenshteinTransformerModel(TransformerModel):
_skip_encoder_out(self.encoder, encoder_out, can_ins_word)
)
word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
- word_ins_pred = word_ins_score.max(-1)[1]
_tokens, _scores = _apply_ins_words(
output_tokens[can_ins_word],
output_scores[can_ins_word],
@@ -494,6 +499,9 @@ class LevenshteinTransformerModel(TransformerModel):
output_scores = _fill(output_scores, can_ins_word, _scores, 0)
attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
if history is not None:
history.append(output_tokens.clone())
# delete some unnecessary paddings
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
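On the LevT bug fix above: after `word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)`, both tensors are already (batch, length). The stale removed line re-derived word_ins_pred as word_ins_score.max(-1)[1], which argmaxes over the length dimension and yields position indices rather than vocabulary ids. A quick shape check:

```python
import torch
import torch.nn.functional as F

word_ins_out = torch.randn(2, 5, 100)   # (batch, length, vocab) logits
word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
assert word_ins_pred.shape == (2, 5)    # per-position token ids

stale_pred = word_ins_score.max(-1)[1]  # the removed line: argmax over length
assert stale_pred.shape == (2,)         # position indices, not token ids
```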
@@ -504,6 +512,7 @@ class LevenshteinTransformerModel(TransformerModel):
output_tokens=output_tokens,
output_scores=output_scores,
attn=attn,
history=history
)
def initialize_output_tokens(self, encoder_out, src_tokens):
@@ -520,6 +529,7 @@ class LevenshteinTransformerModel(TransformerModel):
attn=None,
step=0,
max_step=0,
history=None
)
nonautoregressive_transformer.py
@@ -117,6 +117,7 @@ class NATransformerModel(TransformerModel):
step = decoder_out.step
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
history = decoder_out.history
# execute the decoder
output_masks = output_tokens.ne(self.pad)
@@ -127,12 +128,15 @@ class NATransformerModel(TransformerModel):
step=step,
)
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
if history is not None:
history.append(output_tokens.clone())
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
history=history
)
def initialize_output_tokens(self, encoder_out, src_tokens):
@@ -160,6 +164,7 @@ class NATransformerModel(TransformerModel):
attn=None,
step=0,
max_step=0,
history=None
)
options.py
@@ -506,6 +506,8 @@ def add_generation_args(parser):
help='maximum iterations for iterative refinement.')
group.add_argument('--iter-decode-force-max-iter', action='store_true',
help='if set, run exactly the maximum number of iterations without early stopping')
group.add_argument('--retain-iter-history', action='store_true',
help='if set, decoding returns the whole history of iterative refinement')
# special decoding format for advanced decoding.
group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
translation_lev.py
@@ -131,7 +131,8 @@ class TranslationLevenshteinTask(TranslationTask):
eos_penalty=getattr(args, 'iter_decode_eos_penalty', 0.0),
max_iter=getattr(args, 'iter_decode_max_iter', 10),
decoding_format=getattr(args, 'decoding_format', None),
- adaptive=not getattr(args, 'iter_decode_force_max_iter', False))
+ adaptive=not getattr(args, 'iter_decode_force_max_iter', False),
+ retain_history=getattr(args, 'retain_iter_history', False))
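The getattr(..., default) pattern used above keeps older configs working: runs whose saved args predate the new flag silently fall back to False. Minimal illustration:

```python
from argparse import Namespace

old_args = Namespace()                          # no retain_iter_history attribute
new_args = Namespace(retain_iter_history=True)
assert getattr(old_args, 'retain_iter_history', False) is False
assert getattr(new_args, 'retain_iter_history', False) is True
```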
def train_step(self,
sample,
generate.py
@@ -162,6 +162,15 @@ def main(args):
if args.print_step:
print('I-{}\t{}'.format(sample_id, hypo['steps']))
if getattr(args, 'retain_iter_history', False):
print("\n".join([
'E-{}_{}\t{}'.format(
sample_id, step,
utils.post_process_prediction(
h['tokens'].int().cpu(),
src_str, None, None, tgt_dict, None)[1])
for step, h in enumerate(hypo['history'])]))
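Each retained step is emitted as its own 'E-<sample_id>_<step>' line. A standalone sketch of the resulting output format, with hypothetical detokenized steps in place of the real post-processed predictions:

```python
sample_id = 3
history = ["der Hund .", "der kleine Hund .", "der kleine Hund lief ."]
print("\n".join('E-{}_{}\t{}'.format(sample_id, step, h)
                for step, h in enumerate(history)))
# E-3_0	der Hund .
# E-3_1	der kleine Hund .
# E-3_2	der kleine Hund lief .
```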
# Score only the top hypothesis
if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None: