Commit 34e79c58 authored by Jungo Kasai's avatar Jungo Kasai Committed by Facebook Github Bot

ensemble levts

Summary:
Add ensemble wrappers to the Levenshtein NAT.

Levenshtein Transformer: final-softmax ensemble over a pipeline of three steps:
1. Deletion
2. Placeholder Insertion
3. Word Selection

Each step scores its candidates, averages the scores over the ensemble, and then makes hard decisions with argmax; only then does the next step run. By design, the three steps cannot be executed in parallel.
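A minimal sketch of the per-step averaging (editor's illustration, not code from this commit; `outs` stands in for the per-model decoder logits at one pipeline step):

import math
import torch
import torch.nn.functional as F

def ensemble_log_probs(outs):
    # log of the mean probability: logsumexp over the model dimension,
    # then subtract log of the ensemble size
    log_probs = torch.stack([F.log_softmax(o, dim=-1) for o in outs], dim=0)
    return torch.logsumexp(log_probs, dim=0) - math.log(len(outs))

# hard decision for this step; the next step consumes its output
# pred = ensemble_log_probs(outs).max(-1)[1]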

Reviewed By: kahne

Differential Revision: D17723202

fbshipit-source-id: 05f7a4fcd922a972cc4796ca397e8220f0b4d53e
parent c2165224
@@ -6,6 +6,8 @@
 import torch
 from fairseq.models.model_utils import skip_tensors as _skip
+from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
+from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel


 class IterativeRefinementGenerator(object):
@@ -44,9 +46,14 @@ class IterativeRefinementGenerator(object):
     @torch.no_grad()
     def generate(self, models, sample, prefix_tokens=None):
-        # TODO: model ensemble
-        assert len(models) == 1, 'only support single model'
-        model = models[0]
+        if len(models) == 1:
+            # Keep this for other NAT models for which we have yet to
+            # implement ensemble wrappers. Later delete this.
+            model = models[0]
+        elif isinstance(models[0], LevenshteinTransformerModel):
+            model = EnsembleLevT(models)
+        else:
+            raise NotImplementedError

         if not self.retain_dropout:
             model.eval()
...
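For context, a hedged usage sketch (not part of the diff): with this dispatch in place, `generate` can take several Levenshtein Transformer checkpoints at once. `checkpoint_utils.load_model_ensemble` is fairseq's standard ensemble loader; the checkpoint paths and the elided generator arguments below are hypothetical.

from fairseq import checkpoint_utils

# load two LevT checkpoints trained with the same dictionary (paths made up)
models, _args = checkpoint_utils.load_model_ensemble(['levt_a.pt', 'levt_b.pt'])

# generator = IterativeRefinementGenerator(tgt_dict, ...)  # args elided
# hypos = generator.generate(models, sample)  # models get wrapped in EnsembleLevT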
new file: fairseq/models/nonautoregressive_ensembles.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq.models.levenshtein_transformer import (
    _apply_del_words,
    _apply_ins_masks,
    _apply_ins_words,
)
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
class BasicEnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.bos = self.models[0].decoder.dictionary.bos()
        self.eos = self.models[0].decoder.dictionary.eos()
        self.pad = self.models[0].decoder.dictionary.pad()
        self.unk = self.models[0].decoder.dictionary.unk()

    def has_encoder(self):
        return hasattr(self.models[0], 'encoder')

    def max_decoder_positions(self):
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        if not self.has_encoder():
            return None
        return [model.forward_encoder(encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, *inputs):
        raise NotImplementedError

    def initialize_output_tokens(self, *inputs):
        raise NotImplementedError
class EnsembleLevT(BasicEnsembleModel):
    """A wrapper around an ensemble of Levenshtein Transformer models."""

    def __init__(self, models):
        super().__init__(models)

    @torch.no_grad()
    def forward_decoder(self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs):
        # LevT ensembling runs a pipeline of three steps: deletion, placeholder
        # insertion, and word insertion. Scores are averaged over the ensemble
        # within each step; the steps themselves must run sequentially because
        # each one depends on the previous step's hard decisions.
        output_tokens = decoder_out["output_tokens"]
        output_scores = decoder_out["output_scores"]
        attn = decoder_out["attn"]

        bsz = output_tokens.size(0)
        if max_ratio is None:
            # one length cap per sentence; a bare `.new()` would create an
            # empty tensor, so allocate `bsz` entries
            max_lens = output_tokens.new(bsz).fill_(255)
        else:
            if encoder_outs[0]["encoder_padding_mask"] is None:
                src_lens = (
                    encoder_outs[0]["encoder_out"]
                    .new(bsz)
                    .fill_(encoder_outs[0]["encoder_out"].size(1))
                )
            else:
                src_lens = (~encoder_outs[0]["encoder_padding_mask"]).sum(1)
            max_lens = (src_lens * max_ratio).clamp(min=10).long()

        # delete words: do not delete when the sequence is only <s> </s>
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:  # otherwise there is nothing to delete; skip
            output_tokens, output_scores, attn = self.forward_word_del(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_del_word,
            )

        # insert placeholders
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            output_tokens, output_scores = self.forward_mask_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                can_ins_mask,
                eos_penalty,
                max_lens,
            )

        # insert words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_ins_word,
            )

        # delete some unnecessary paddings
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return {
            "output_tokens": output_tokens,
            "output_scores": output_scores,
            "attn": attn,
        }
    def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can_del_word):
        word_del_score_avg = []
        word_del_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_del_out, word_del_attn = model.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip(encoder_out, can_del_word),
            )
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_score_avg.append(word_del_score)
            word_del_attn_avg.append(word_del_attn)
        # average probabilities in log space: logsumexp over models - log(E)
        word_del_score_avg = torch.logsumexp(
            torch.stack(word_del_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        word_del_pred = word_del_score_avg.max(-1)[1].bool()
        if word_del_attn_avg[0] is not None:
            # average over the stacked model dimension (dividing the stack
            # alone would leave an extra leading dimension)
            word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0).sum(0) / len(self.models)
        else:
            word_del_attn_avg = None
        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn_avg,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.)
        return output_tokens, output_scores, attn
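    # Editor's sketch (not in the commit): the log-domain averaging above is
    # exactly the log of the mean probability. Toy check, two models, three
    # classes:
    #   p1 = torch.tensor([0.7, 0.2, 0.1]).log()
    #   p2 = torch.tensor([0.5, 0.3, 0.2]).log()
    #   torch.logsumexp(torch.stack([p1, p2]), dim=0) - math.log(2)
    #   # -> log([0.60, 0.25, 0.15]), the elementwise mean of p1 and p2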
    def forward_mask_ins(self, encoder_outs, output_tokens, output_scores, can_ins_mask, eos_penalty, max_lens):
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip(encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                # class 0 = "insert zero placeholders"; penalizing it
                # discourages early stopping, i.e. favors longer outputs
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        # never insert beyond the per-sentence length cap
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
        )
        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores
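    # Editor's sketch (not in the commit): how eos_penalty shifts the argmax.
    # With log-probs log([0.5, 0.4, 0.1]) for "insert 0/1/2 placeholders",
    # a penalty of 0.5 on class 0 gives -1.19 < -0.92, flipping the
    # prediction from 0 insertions to 1 and lengthening the output:
    #   scores = torch.tensor([0.5, 0.4, 0.1]).log()
    #   scores[0] -= 0.5
    #   scores.max(-1)[1]  # -> tensor(1), instead of 0 without the penalty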
    def forward_word_ins(self, encoder_outs, output_tokens, output_scores, attn, can_ins_word):
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip(encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        if word_ins_attn_avg[0] is not None:
            # average over the stacked model dimension, as in forward_word_del
            word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0).sum(0) / len(self.models)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)
        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )
        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # fill with the ensemble-averaged attention, not the last model's
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.)
        return output_tokens, output_scores, attn
    def initialize_output_tokens(self, encoder_outs, src_tokens):
        # LevT doesn't do length prediction, so any single model can initialize
        return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)
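A hedged end-to-end sketch of how this wrapper is driven (editor's illustration; the real loop, termination checks, and argument plumbing live in IterativeRefinementGenerator, and the names below are simplified assumptions):

models = [...]  # LevenshteinTransformerModel checkpoints sharing one dictionary
ensemble = EnsembleLevT(models)

# one encoder pass per model, reused across refinement iterations
encoder_outs = ensemble.forward_encoder([src_tokens, src_lengths])
decoder_out = ensemble.initialize_output_tokens(encoder_outs, src_tokens)

for _ in range(max_iter):  # max_iter is a decoding hyperparameter
    # each call runs the full pipeline: delete -> insert placeholders -> fill words
    decoder_out = ensemble.forward_decoder(
        decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=2.0
    )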