Commit fb3e1e36 authored by Changhan Wang's avatar Changhan Wang Committed by Facebook Github Bot

update LevT ensemble

Summary: Update the LevT ensemble class to match the recent API changes in the LevT and iterative-refinement decoder classes.

Reviewed By: jhcross

Differential Revision: D18689292

fbshipit-source-id: 64d4cdb6513a32a32d49e0ebf57886ae576722d4
parent 5349052a
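Context for the diff below: in the updated API the decoder state is a namedtuple rather than a plain dict, so fields are read as attributes (`decoder_out.output_tokens`) and updated immutably with `._replace(...)`. A minimal sketch of the structure this assumes (field names inferred from the diff; the exact definition lives in fairseq's iterative refinement generator):

    # Minimal sketch, not the verbatim fairseq definition.
    # output_tokens / output_scores / attn / history appear in the diff;
    # step and max_step are assumed extra fields for illustration.
    from collections import namedtuple

    DecoderOut = namedtuple(
        "DecoderOut",
        ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
    )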
@@ -9,12 +9,25 @@ import torch
 import torch.nn.functional as F
 from fairseq.models.levenshtein_transformer import (
+    _fill,
+    _skip,
+    _skip_encoder_out,
     _apply_ins_masks,
     _apply_ins_words,
     _apply_del_words,
 )
-from fairseq.models.model_utils import fill_tensors as _fill
+
+
+class _EnsembleModelEncoder(object):
+    def __init__(self, models):
+        self.models = models
+
+    def reorder_encoder_out(self, encoder_outs, new_order):
+        encoder_outs = [
+            model.encoder.reorder_encoder_out(encoder_out, new_order)
+            for model, encoder_out in zip(self.models, encoder_outs)
+        ]
+        return encoder_outs
+
+
 class BasicEnsembleModel(torch.nn.Module):
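The new `_EnsembleModelEncoder` wrapper lets generator code call `reorder_encoder_out` on the ensemble as if it were a single model, fanning the call out to each member model with its own encoder output. A runnable toy illustration of that delegation pattern (toy classes; the real encoders come from fairseq LevT models):

    import torch

    # Toy stand-ins for illustration only.
    class ToyEncoder:
        def reorder_encoder_out(self, encoder_out, new_order):
            # Select batch entries in the new order; real encoders also
            # reorder padding masks and any other fields they carry.
            return encoder_out.index_select(0, new_order)

    class ToyModel:
        def __init__(self):
            self.encoder = ToyEncoder()

    models = [ToyModel(), ToyModel()]
    encoder_outs = [torch.randn(3, 5) for _ in models]  # one output per model
    new_order = torch.tensor([2, 0, 1])
    encoder_outs = [
        m.encoder.reorder_encoder_out(eo, new_order)
        for m, eo in zip(models, encoder_outs)
    ]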
@@ -27,6 +40,7 @@ class BasicEnsembleModel(torch.nn.Module):
         self.eos = self.models[0].decoder.dictionary.eos()
         self.pad = self.models[0].decoder.dictionary.pad()
         self.unk = self.models[0].decoder.dictionary.unk()
+        self.encoder = _EnsembleModelEncoder(self.models)

     def has_encoder(self):
         return hasattr(self.models[0], 'encoder')
@@ -60,18 +74,18 @@ class EnsembleLevT(BasicEnsembleModel):
         # A pipeline of three steps: deletion, placeholder insertion, and word insertion.
         # Scores must be averaged step by step because each step depends on the previous one.
         # deletion
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
-        attn = decoder_out["attn"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores
+        attn = decoder_out.attn

         bsz = output_tokens.size(0)
         if max_ratio is None:
             max_lens = output_tokens.new().fill_(255)
         else:
-            if encoder_outs[0]["encoder_padding_mask"] is None:
-                src_lens = encoder_outs[0]["encoder_out"].new(bsz).fill_(encoder_outs[0]["encoder_out"].size(1))
+            if encoder_outs[0].encoder_padding_mask is None:
+                src_lens = encoder_outs[0].encoder_out.new(bsz).fill_(encoder_outs[0].encoder_out.size(1))
             else:
-                src_lens = (~encoder_outs[0]["encoder_padding_mask"]).sum(1)
+                src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1)
             max_lens = (src_lens * max_ratio).clamp(min=10).long()

         # delete words
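A standalone illustration of the `max_lens` arithmetic above, with made-up values: padding positions are True in the mask, so `~mask` counts the real source tokens per sentence, and the result is scaled by `max_ratio` and floored at 10.

    import torch

    # Illustration only; values are made up.
    encoder_padding_mask = torch.tensor([
        [False, False, False, False],  # source length 4
        [False, False, True,  True],   # source length 2
    ])
    src_lens = (~encoder_padding_mask).sum(1)            # tensor([4, 2])
    max_ratio = 2.0
    max_lens = (src_lens * max_ratio).clamp(min=10).long()
    print(max_lens)                                      # tensor([10, 10])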
@@ -114,11 +128,12 @@ class EnsembleLevT(BasicEnsembleModel):
         output_tokens = output_tokens[:, :cut_off]
         output_scores = output_scores[:, :cut_off]
         attn = None if attn is None else attn[:, :cut_off, :]
-        return {
-            "output_tokens": output_tokens,
-            "output_scores": output_scores,
-            "attn": attn,
-        }
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=attn,
+            history=None
+        )

     def forward_word_del(self, encoder_outs, output_tokens, output_scores, attn, can_del_word):
         word_del_score_avg = []
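The new return path uses `namedtuple._replace`, which copies the state and overwrites only the named fields, so any other fields the decoder state carries survive unchanged (here `history` is explicitly reset to None). A generic example of the pattern:

    from collections import namedtuple

    # Generic illustration of _replace: unnamed fields carry over.
    State = namedtuple("State", ["output_tokens", "output_scores", "step"])
    state = State(output_tokens=[1, 2, 3], output_scores=[0.1, 0.2, 0.3], step=4)
    new_state = state._replace(output_tokens=[1, 2])
    print(new_state.step)  # 4 -- preserved from the old state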
@@ -126,7 +141,7 @@ class EnsembleLevT(BasicEnsembleModel):
         for model, encoder_out in zip(self.models, encoder_outs):
             word_del_out, word_del_attn = model.decoder.forward_word_del(
                 _skip(output_tokens, can_del_word),
-                _skip(encoder_out, can_del_word),
+                _skip_encoder_out(model.encoder, encoder_out, can_del_word),
             )
             word_del_score = F.log_softmax(word_del_out, 2)
             word_del_score_avg.append(word_del_score)
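`_skip_encoder_out` replaces the old `_skip` call for encoder outputs: instead of slicing a dict of tensors directly, it goes through the model's own `reorder_encoder_out`, which knows how to reorder every field of the now-structured encoder output. A rough sketch of what such a helper plausibly looks like (an assumption for illustration; the real implementation is imported from `fairseq.models.levenshtein_transformer`):

    import torch

    # Sketch, not the verbatim fairseq helper: keep only the batch entries
    # selected by a boolean mask by delegating to the encoder itself.
    def _skip_encoder_out(encoder, encoder_out, mask):
        if not mask.any():
            return encoder_out
        return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze(1))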
@@ -157,7 +172,7 @@ class EnsembleLevT(BasicEnsembleModel):
         for model, encoder_out in zip(self.models, encoder_outs):
             mask_ins_out, _ = model.decoder.forward_mask_ins(
                 _skip(output_tokens, can_ins_mask),
-                _skip(encoder_out, can_ins_mask),
+                _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
             )
             mask_ins_score = F.log_softmax(mask_ins_out, 2)
             if eos_penalty > 0.0:
@@ -186,7 +201,7 @@ class EnsembleLevT(BasicEnsembleModel):
         for model, encoder_out in zip(self.models, encoder_outs):
             word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                 _skip(output_tokens, can_ins_word),
-                _skip(encoder_out, can_ins_word),
+                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
             )
             word_ins_score = F.log_softmax(word_ins_out, 2)
             word_ins_score_avg.append(word_ins_score)
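Each of these loops gathers one log-softmax tensor per model into a `*_score_avg` list; the reduction itself sits in the collapsed parts of the diff. One plausible reduction, shown here as a sketch rather than the committed code, averages the per-model probabilities stably in log space:

    import math
    import torch

    # Sketch of ensembling per-model log-probabilities: average in
    # probability space, computed stably via logsumexp in log space.
    def average_log_probs(scores):              # list of [B, T, V] tensors
        stacked = torch.stack(scores, dim=0)    # [N, B, T, V]
        return torch.logsumexp(stacked, dim=0) - math.log(len(scores))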