Commit 27568a7e authored by Myle Ott, committed by Facebook GitHub Bot

Merge TracingCompliantTransformer and regular Transformer, fix NAT tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899

Differential Revision: D18373060

Pulled By: myleott

fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
parent 2a9b4ec2
@@ -48,7 +48,7 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
         if masks is not None:
             outputs, targets = outputs[masks], targets[masks]

-        if not masks.any():
+        if masks is not None and not masks.any():
             nll_loss = torch.tensor(0)
             loss = nll_loss
         else:
...
@@ -3,11 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import namedtuple
+
 import torch

 from fairseq import utils
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
-from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
+
+DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
+    'output_tokens',
+    'output_scores',
+    'attn',
+    'step',
+    'max_step',
+])


 class IterativeRefinementGenerator(object):

@@ -88,6 +97,8 @@ class IterativeRefinementGenerator(object):
     @torch.no_grad()
     def generate(self, models, sample, prefix_tokens=None):
+        from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+        from fairseq.models.nonautoregressive_ensembles import EnsembleLevT

         if len(models) == 1:
             # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.

@@ -110,7 +121,7 @@ class IterativeRefinementGenerator(object):
         # initialize buffers (very model specific, with length prediction or not)
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
-        prev_output_tokens = prev_decoder_out[0].clone()
+        prev_output_tokens = prev_decoder_out.output_tokens.clone()

         finalized = [[] for _ in range(bsz)]

@@ -150,8 +161,10 @@ class IterativeRefinementGenerator(object):
                 "max_ratio": self.max_ratio,
                 "decoding_format": self.decoding_format,
             }
-            prev_decoder_out[3] = step
-            prev_decoder_out[4] = self.max_iter + 1
+            prev_decoder_out = prev_decoder_out._replace(
+                step=step,
+                max_step=self.max_iter + 1,
+            )

             decoder_out = model.forward_decoder(
                 prev_decoder_out, encoder_out, **decoder_options

@@ -160,24 +173,26 @@ class IterativeRefinementGenerator(object):
             if self.adaptive:
                 # terminate if there is a loop
                 terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
+                )
+                decoder_out = decoder_out._replace(
+                    output_tokens=out_tokens,
+                    output_scores=out_scores,
+                    attn=out_attn,
                 )
-                decoder_out[0] = out_tokens
-                decoder_out[1] = out_scores
-                decoder_out[2] = out_attn

             else:
-                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
+                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

             if step == self.max_iter:  # reach last iteration, terminate
                 terminated.fill_(1)

             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out[0][terminated]
-            finalized_scores = decoder_out[1][terminated]
+            finalized_tokens = decoder_out.output_tokens[terminated]
+            finalized_scores = decoder_out.output_scores[terminated]
             finalized_attn = (
-                None if decoder_out[2] is None else decoder_out[2][terminated]
+                None if decoder_out.attn is None else decoder_out.attn[terminated]
             )

             for i in range(finalized_idxs.size(0)):

@@ -194,10 +209,15 @@ class IterativeRefinementGenerator(object):
                     break

             # for next step
-            prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
-            sent_idxs = _skip(sent_idxs, ~terminated)
+            not_terminated = ~terminated
+            prev_decoder_out = decoder_out._replace(
+                output_tokens=decoder_out.output_tokens[not_terminated],
+                output_scores=decoder_out.output_scores[not_terminated],
+                attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
+            )
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
+            sent_idxs = sent_idxs[not_terminated]
-            prev_output_tokens = prev_decoder_out[0].clone()
+            prev_output_tokens = prev_decoder_out.output_tokens.clone()

         return finalized
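A note on the pattern introduced above: the generator's per-sentence state now lives in an immutable namedtuple rather than a list indexed by position, and every update goes through `._replace`, which returns a fresh copy with the named fields swapped. A minimal standalone sketch of the idiom (the tensor contents here are made up for illustration):

```python
from collections import namedtuple

import torch

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
])

# Build an initial state, as initialize_output_tokens does.
state = DecoderOut(
    output_tokens=torch.full((2, 5), 3, dtype=torch.long),  # dummy token ids
    output_scores=torch.zeros(2, 5),
    attn=None,
    step=0,
    max_step=10,
)

# _replace returns a new namedtuple; the original is untouched, which is
# what makes the state safe to thread through the refinement loop.
state = state._replace(step=state.step + 1)
print(state.step)  # 1
```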
@@ -10,9 +10,9 @@ Ghazvininejad, Marjan, et al.
 arXiv preprint arXiv:1904.09324 (2019).
 """

-from fairseq.utils import new_arange
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.nonautoregressive_transformer import NATransformerModel
+from fairseq.utils import new_arange


 def _skeptical_unmasking(output_scores, output_masks, p):

@@ -55,11 +55,11 @@ class CMLMNATransformerModel(NATransformerModel):

     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
-        step = decoder_out["step"]
-        max_step = decoder_out["max_step"]
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        step = decoder_out.step
+        max_step = decoder_out.max_step
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

         # execute the decoder
         output_masks = output_tokens.eq(self.unk)

@@ -78,7 +78,11 @@ class CMLMNATransformerModel(NATransformerModel):
         output_tokens.masked_fill_(skeptical_mask, self.unk)
         output_scores.masked_fill_(skeptical_mask, 0.0)

-        return {"output_tokens": output_tokens, "output_scores": output_scores}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


 @register_model_architecture("cmlm_transformer", "cmlm_transformer")
...
@@ -6,7 +6,7 @@

 import numpy as np
 import torch
 import torch.nn.functional as F
-from fairseq.utils import new_arange
+
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.levenshtein_transformer import (
     LevenshteinTransformerDecoder,

@@ -14,6 +14,7 @@ from fairseq.models.levenshtein_transformer import (
 )
 from fairseq.models.transformer import Linear, TransformerModel
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import new_arange


 class NegativeDistanceScore(object):

@@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi

 @register_model("insertion_transformer")
 class InsertionTransformerModel(LevenshteinTransformerModel):
-    def __init__(self, encoder, decoder):
-        super().__init__(encoder, decoder)
+    def __init__(self, args, encoder, decoder):
+        super().__init__(args, encoder, decoder)

     @staticmethod
     def add_args(parser):

@@ -169,8 +170,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
         self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
     ):
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

         # TODO: decoding for InsertionTransformer
         word_ins_out = self.decoder.forward_word_ins(
             output_tokens, encoder_out=encoder_out

@@ -187,7 +188,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
         cut_off = output_tokens.ne(self.pad).sum(1).max()
         output_tokens = output_tokens[:, :cut_off]
         output_scores = output_scores[:, :cut_off]
-        return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


 class InsertionTransformerDecoder(LevenshteinTransformerDecoder):

@@ -206,7 +211,7 @@ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
         self.label_tau = getattr(args, "label_tau", None)

     def forward_word_ins(self, prev_output_tokens, encoder_out=None):
-        features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
+        features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
         features = self.pool_out(
             torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
         )
...
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
+
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.nonautoregressive_transformer import NATransformerModel
...
This diff is collapsed.
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict, List
+from typing import List, Optional

 import torch
 from torch import Tensor

@@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask):
     return res


-@torch.jit.script
-def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
-    outputs = {}
-    for s, t in x.items():
-        outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
-    return outputs
-
-
-def skip_tensors(x, mask):
-    """
-    Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
-    """
-    if isinstance(x, int):
-        return x
-
-    if x is None:
-        return None
-
-    if isinstance(x, torch.Tensor):
-        if x.size(0) == mask.size(0):
-            return x[mask]
-        elif x.size(1) == mask.size(0):
-            return x[:, mask]
-
-    if isinstance(x, list):
-        return [skip_tensors(x_i, mask) for x_i in x]
-
-    if isinstance(x, dict):
-        return {k: skip_tensors(v, mask) for k, v in x.items()}
-
-    raise NotImplementedError
-
-
 @torch.jit.script
 def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
     """

@@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):

 @torch.jit.script
-def fill_tensors(x, mask, y, padding_idx: int):
+def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
+    return x if x is not None else y
+
+
+@torch.jit.script
+def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
     """
     Filling tensor x with y at masked positions (dim=0).
     """
-    if x is None or x.size()[0] == 0:
-        return torch.empty([0])
+    if x is None or x.size()[0] == 0 or y is None:
+        return x
     assert x.dim() == y.dim() and mask.size(0) == x.size(0)
     assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))

@@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int):
     else:
         x[mask] = y
     return x
-
-
-def _apply_ins_masks(
-    in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
-):
-    in_masks = in_tokens.ne(padding_idx)
-    in_lengths = in_masks.sum(1)
-
-    # HACK: hacky way to shift all the paddings to eos first.
-    in_tokens.masked_fill_(~in_masks, eos_idx)
-    mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
-
-    out_lengths = in_lengths + mask_ins_pred.sum(1)
-    out_max_len = out_lengths.max()
-    out_masks = (
-        torch.arange(out_max_len, device=out_lengths.device)[None, :]
-        < out_lengths[:, None]
-    )
-
-    reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
-    out_tokens = (
-        in_tokens.new_zeros(in_tokens.size(0), out_max_len)
-        .fill_(padding_idx)
-        .masked_fill_(out_masks, unk_idx)
-    )
-    out_tokens[:, 0] = in_tokens[:, 0]
-    out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
-
-    out_scores = None
-    if in_scores is not None:
-        in_scores.masked_fill_(~in_masks, 0)
-        out_scores = in_scores.new_zeros(*out_tokens.size())
-        out_scores[:, 0] = in_scores[:, 0]
-        out_scores.scatter_(1, reordering, in_scores[:, 1:])
-
-    return out_tokens, out_scores
-
-
-def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
-    word_ins_masks = in_tokens.eq(unk_idx)
-    out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
-
-    if in_scores is not None:
-        out_scores = in_scores.masked_scatter(
-            word_ins_masks, word_ins_scores[word_ins_masks]
-        )
-    else:
-        out_scores = None
-
-    return out_tokens, out_scores
-
-
-def _apply_del_words(
-    in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
-):
-    # apply deletion to a tensor
-    in_masks = in_tokens.ne(padding_idx)
-    bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
-
-    max_len = in_tokens.size(1)
-    word_del_pred.masked_fill_(~in_masks, 1)
-    word_del_pred.masked_fill_(bos_eos_masks, 0)
-
-    reordering = (
-        torch.arange(max_len, device=in_tokens.device)[None, :]
-        .expand_as(in_tokens)
-        .contiguous()
-        .masked_fill_(word_del_pred, max_len)
-        .sort(1)[1]
-    )
-
-    out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
-
-    out_scores = None
-    if in_scores is not None:
-        out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
-
-    out_attn = None
-    if in_attn is not None:
-        _mask = word_del_pred[:, :, None].expand_as(in_attn)
-        _reordering = reordering[:, :, None].expand_as(in_attn)
-        out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
-
-    return out_tokens, out_scores, out_attn
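Why the new `coalesce` helper and the `Optional[Tensor]` annotations on `fill_tensors`: TorchScript requires optional arguments to be declared explicitly and only narrows the type after an `is None` check, so the old unannotated signature could not cleanly accept or return a missing attention tensor. A small self-contained sketch of the idiom, reusing the `coalesce` definition from the hunk above:

```python
from typing import Optional

import torch
from torch import Tensor


@torch.jit.script
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    # TorchScript narrows Optional[Tensor] to Tensor in the not-None branch.
    return x if x is not None else y


print(coalesce(None, torch.ones(2)))            # falls back to y: tensor([1., 1.])
print(coalesce(torch.zeros(2), torch.ones(2)))  # keeps x: tensor([0., 0.])
```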
@@ -3,11 +3,18 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import math
+
 import torch
 import torch.nn.functional as F
-import math

-from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
-from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words
+from fairseq.models.levenshtein_transformer import (
+    _skip,
+    _apply_ins_masks,
+    _apply_ins_words,
+    _apply_del_words,
+)
+from fairseq.models.model_utils import fill_tensors as _fill


 class BasicEnsembleModel(torch.nn.Module):
...
@@ -5,7 +5,9 @@

 import torch
 import torch.nn.functional as F

 from fairseq import utils
+from fairseq.iterative_refinement_generator import DecoderOut
 from fairseq.models import register_model, register_model_architecture
 from fairseq.models.transformer import (
     Embedding,

@@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens):

 @register_model("nonautoregressive_transformer")
 class NATransformerModel(TransformerModel):
-    def __init__(self, encoder, decoder):
-        super().__init__(encoder, decoder)
+    def __init__(self, args, encoder, decoder):
+        super().__init__(args, encoder, decoder)
         self.tgt_dict = decoder.dictionary
         self.bos = decoder.dictionary.bos()
         self.eos = decoder.dictionary.eos()

@@ -112,9 +114,9 @@ class NATransformerModel(TransformerModel):
         return self.encoder(*encoder_inputs)

     def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
-        step = decoder_out["step"]
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        step = decoder_out.step
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

         # execute the decoder
         output_masks = output_tokens.ne(self.pad)

@@ -127,12 +129,16 @@ class NATransformerModel(TransformerModel):
         output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
         output_scores.masked_scatter_(output_masks, _scores[output_masks])

-        return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )

     def initialize_output_tokens(self, encoder_out, src_tokens):
         # length prediction
         _, length_tgt = self.decoder.forward_length_prediction(encoder_out)
-        max_length = length_tgt.max()
+        max_length = length_tgt.clamp_(min=2).max()
         idx_length = utils.new_arange(src_tokens, max_length)

         initial_output_tokens = src_tokens.new_zeros(

@@ -146,13 +152,15 @@ class NATransformerModel(TransformerModel):
         initial_output_scores = initial_output_tokens.new_zeros(
             *initial_output_tokens.size()
-        ).type_as(encoder_out["encoder_out"])
+        ).type_as(encoder_out.encoder_out)

-        return {
-            "output_tokens": initial_output_tokens,
-            "output_scores": initial_output_scores,
-            "attn": None
-        }
+        return DecoderOut(
+            output_tokens=initial_output_tokens,
+            output_scores=initial_output_scores,
+            attn=None,
+            step=0,
+            max_step=0,
+        )


 class NATransformerDecoder(TransformerDecoder):

@@ -220,8 +228,8 @@ class NATransformerDecoder(TransformerDecoder):
         """
         # embedding
         if embedding_copy:
-            src_embd = encoder_out["encoder_embedding"]
-            src_mask = encoder_out["encoder_padding_mask"]
+            src_embd = encoder_out.encoder_embedding
+            src_mask = encoder_out.encoder_padding_mask
             src_mask = (
                 ~src_mask
                 if src_mask is not None

@@ -253,10 +261,8 @@ class NATransformerDecoder(TransformerDecoder):
             x, attn = layer(
                 x,
-                encoder_out["encoder_out"] if encoder_out is not None else None,
-                encoder_out["encoder_padding_mask"]
-                if encoder_out is not None
-                else None,
+                encoder_out.encoder_out if encoder_out is not None else None,
+                encoder_out.encoder_padding_mask if encoder_out is not None else None,
                 self_attn_mask=None,
                 self_attn_padding_mask=decoder_padding_mask,
             )

@@ -311,8 +317,8 @@ class NATransformerDecoder(TransformerDecoder):
         return copied_embedding

     def forward_length_prediction(self, encoder_out, tgt_tokens=None):
-        enc_feats = encoder_out["encoder_out"]  # T x B x C
-        src_masks = encoder_out["encoder_padding_mask"]  # B x T or None
+        enc_feats = encoder_out.encoder_out  # T x B x C
+        src_masks = encoder_out.encoder_padding_mask  # B x T or None
         if self.pred_length_offset:
             if src_masks is None:
...
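One behavioral fix worth flagging above: `initialize_output_tokens` now clamps the predicted target length to at least 2, so every hypothesis has room for its bos and eos tokens before the unk placeholders are filled in. A standalone sketch of that initialization logic (the dictionary indices here are hypothetical, and plain `torch.arange` stands in for `utils.new_arange`):

```python
import torch

bos, eos, unk, pad = 0, 2, 3, 1  # hypothetical dictionary indices

length_tgt = torch.tensor([1, 4])      # first length is degenerate (< 2)
length_tgt = length_tgt.clamp_(min=2)  # ensure room for bos and eos
max_length = int(length_tgt.max())

# Fill each row with unk up to its predicted length, pad beyond it.
idx_length = torch.arange(max_length)
tokens = torch.full((length_tgt.size(0), max_length), pad, dtype=torch.long)
tokens.masked_fill_(idx_length[None, :] < length_tgt[:, None], unk)
tokens[:, 0] = bos
tokens.scatter_(1, length_tgt[:, None] - 1, eos)  # eos at the last position
print(tokens)  # tensor([[0, 2, 1, 1],
               #         [0, 3, 3, 2]])
```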
@@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder):
                 - a dictionary of additional data, where 'inner_states'
                   is a list of hidden states.
         """
-        x, extra = self.extract_features(src_tokens, return_all_hiddens)
+        x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
         if not features_only:
             x = self.output_layer(x, masked_tokens=masked_tokens)
         return x, extra
...
This diff is collapsed.
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import namedtuple
 import math

 import torch

@@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel):
         return decoder_out


+EncoderOut = namedtuple('TransformerEncoderOut', [
+    'encoder_out',  # T x B x C
+    'encoder_padding_mask',  # B x T
+    'encoder_embedding',  # B x T x C
+    'encoder_states',  # List[T x B x C]
+])
+
+
 class TransformerEncoder(FairseqEncoder):
     """
     Transformer encoder consisting of *args.encoder_layers* layers. Each layer

@@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder):
                 intermediate hidden states (default: False).

         Returns:
-            dict:
+            namedtuple:
                 - **encoder_out** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
                 - **encoder_padding_mask** (ByteTensor): the positions of
                   padding elements of shape `(batch, src_len)`
+                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
+                  of shape `(batch, src_len, embed_dim)`
                 - **encoder_states** (List[Tensor]): all intermediate
                   hidden states of shape `(src_len, batch, embed_dim)`.
                   Only populated if *return_all_hiddens* is True.

@@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder):
         if return_all_hiddens:
             encoder_states[-1] = x

-        return {
-            'encoder_out': x,  # T x B x C
-            'encoder_padding_mask': encoder_padding_mask,  # B x T
-            'encoder_embedding': encoder_embedding,  # B x T x C
-            'encoder_states': encoder_states,  # List[T x B x C]
-        }
+        return EncoderOut(
+            encoder_out=x,  # T x B x C
+            encoder_padding_mask=encoder_padding_mask,  # B x T
+            encoder_embedding=encoder_embedding,  # B x T x C
+            encoder_states=encoder_states,  # List[T x B x C]
+        )

     def reorder_encoder_out(self, encoder_out, new_order):
         """

@@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder):
         Returns:
             *encoder_out* rearranged according to *new_order*
         """
-        if encoder_out['encoder_out'] is not None:
-            encoder_out['encoder_out'] = \
-                encoder_out['encoder_out'].index_select(1, new_order)
-        if encoder_out['encoder_padding_mask'] is not None:
-            encoder_out['encoder_padding_mask'] = \
-                encoder_out['encoder_padding_mask'].index_select(0, new_order)
-        if encoder_out.get('encoder_states', None) is not None:
-            for idx, state in enumerate(encoder_out['encoder_states']):
-                encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
+        if encoder_out.encoder_out is not None:
+            encoder_out = encoder_out._replace(
+                encoder_out=encoder_out.encoder_out.index_select(1, new_order)
+            )
+        if encoder_out.encoder_padding_mask is not None:
+            encoder_out = encoder_out._replace(
+                encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
+            )
+        if encoder_out.encoder_embedding is not None:
+            encoder_out = encoder_out._replace(
+                encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
+            )
+        if encoder_out.encoder_states is not None:
+            for idx, state in enumerate(encoder_out.encoder_states):
+                encoder_out.encoder_states[idx] = state.index_select(1, new_order)
         return encoder_out

     def max_positions(self):

@@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         encoder_out=None,
         incremental_state=None,
         features_only=False,
-        **extra_args,
+        **extra_args
     ):
         """
         Args:
             prev_output_tokens (LongTensor): previous decoder outputs of shape
                 `(batch, tgt_len)`, for teacher forcing
-            encoder_out (Tensor, optional): output from the encoder, used for
+            encoder_out (optional): output from the encoder, used for
                 encoder-side attention
             incremental_state (dict): dictionary used for storing state during
                 :ref:`Incremental decoding`

@@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 - a dictionary with any model-specific outputs
         """
         x, extra = self.extract_features(
-            prev_output_tokens, encoder_out, incremental_state, **extra_args,
+            prev_output_tokens,
+            encoder_out=encoder_out,
+            incremental_state=incremental_state,
+            **extra_args
         )
         if not features_only:
             x = self.output_layer(x)

@@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             encoder_state = None
             if encoder_out is not None:
                 if self.layer_wise_attention:
-                    encoder_state = encoder_out['encoder_states'][idx]
+                    encoder_state = encoder_out.encoder_states[idx]
                 else:
-                    encoder_state = encoder_out['encoder_out']
+                    encoder_state = encoder_out.encoder_out

             if incremental_state is None and not full_context_alignment:
                 self_attn_mask = self.buffered_future_mask(x)

@@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             x, layer_attn = layer(
                 x,
                 encoder_state,
-                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
+                encoder_out.encoder_padding_mask if encoder_out is not None else None,
                 incremental_state,
                 self_attn_mask=self_attn_mask,
                 self_attn_padding_mask=self_attn_padding_mask,
...
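The `reorder_encoder_out` rewrite is the heart of the tracing change: since `EncoderOut` is an immutable namedtuple, each field is rebuilt with `._replace` and `index_select` along its batch dimension instead of being mutated in a dict. A cut-down sketch of the pattern with a two-field tuple (the field names and T x B x C layout follow the comments in the diff; the shapes are illustrative):

```python
from collections import namedtuple

import torch

EncoderOut = namedtuple('TransformerEncoderOut',
                        ['encoder_out', 'encoder_padding_mask'])

enc = EncoderOut(
    encoder_out=torch.randn(7, 4, 16),              # T x B x C
    encoder_padding_mask=torch.zeros(4, 7).bool(),  # B x T
)

# Keep only batch elements 0 and 2, e.g. the unfinished sentences in
# iterative refinement, or hypotheses surviving a beam-search step.
new_order = torch.tensor([0, 2])
enc = enc._replace(
    encoder_out=enc.encoder_out.index_select(1, new_order),  # batch dim is 1
    encoder_padding_mask=enc.encoder_padding_mask.index_select(0, new_order),
)
print(enc.encoder_out.shape)  # torch.Size([7, 2, 16])
```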
@@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):

     def forward(self, encoder_out):
         if not (
-            isinstance(encoder_out, dict)
-            and 'encoder_out' in encoder_out
-            and 'encoder_padding_mask' in encoder_out
-            and encoder_out['encoder_out'].size(2) == self.embed_dim
+            hasattr(encoder_out, 'encoder_out')
+            and hasattr(encoder_out, 'encoder_padding_mask')
+            and encoder_out.encoder_out.size(2) == self.embed_dim
         ):
             raise ValueError('Unexpected format for encoder_out')

         # mean pooling over time
-        encoder_padding_mask = encoder_out['encoder_padding_mask']  # B x T
-        encoder_out = encoder_out['encoder_out'].transpose(0, 1)  # B x T x C
+        encoder_padding_mask = encoder_out.encoder_padding_mask  # B x T
+        encoder_out = encoder_out.encoder_out.transpose(0, 1)  # B x T x C

         if encoder_padding_mask is not None:
             encoder_out = encoder_out.clone()  # required because of transpose above
             encoder_out[encoder_padding_mask] = 0
...
@@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase):
             ])
             generate_main(data_dir)

+    def test_cmlm_transformer(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir, ['--joined-dictionary'])
+                train_translation_model(data_dir, 'cmlm_transformer', [
+                    '--apply-bert-init',
+                    '--criterion', 'nat_loss',
+                    '--noise', 'full_mask',
+                    '--pred-length-offset',
+                    '--length-loss-factor', '0.1'
+                ], task='translation_lev')
+                generate_main(data_dir, [
+                    '--task', 'translation_lev',
+                    '--iter-decode-max-iter', '9',
+                    '--iter-decode-eos-penalty', '0',
+                    '--print-step',
+                ])
+
     def test_levenshtein_transformer(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
                 create_dummy_data(data_dir)
-                preprocess_translation_data(data_dir)
+                preprocess_translation_data(data_dir, ['--joined-dictionary'])
                 train_translation_model(data_dir, 'levenshtein_transformer', [
                     '--apply-bert-init', '--early-exit', '6,6,6',
                     '--criterion', 'nat_loss'
                 ], task='translation_lev')
-                generate_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'translation_lev',
+                    '--iter-decode-max-iter', '9',
+                    '--iter-decode-eos-penalty', '0',
+                    '--print-step',
+                ])

     def test_nonautoregressive_transformer(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
                 create_dummy_data(data_dir)
-                preprocess_translation_data(data_dir)
+                preprocess_translation_data(data_dir, ['--joined-dictionary'])
                 train_translation_model(data_dir, 'nonautoregressive_transformer', [
                     '--apply-bert-init', '--src-embedding-copy', '--criterion',
                     'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
                     '--length-loss-factor', '0.1'
                 ], task='translation_lev')
-                generate_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'translation_lev',
+                    '--iter-decode-max-iter', '9',
+                    '--iter-decode-eos-penalty', '0',
+                    '--print-step',
+                ])

     def test_iterative_nonautoregressive_transformer(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
                 create_dummy_data(data_dir)
-                preprocess_translation_data(data_dir)
+                preprocess_translation_data(data_dir, ['--joined-dictionary'])
                 train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
                     '--apply-bert-init', '--src-embedding-copy', '--criterion',
                     'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
                     '--dae-ratio', '0.5', '--train-step', '3'
                 ], task='translation_lev')
-                generate_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'translation_lev',
+                    '--iter-decode-max-iter', '9',
+                    '--iter-decode-eos-penalty', '0',
+                    '--print-step',
+                ])

     def test_insertion_transformer(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
                 create_dummy_data(data_dir)
-                preprocess_translation_data(data_dir)
+                preprocess_translation_data(data_dir, ['--joined-dictionary'])
                 train_translation_model(data_dir, 'insertion_transformer', [
                     '--apply-bert-init', '--criterion', 'nat_loss', '--noise',
                     'random_mask'
                 ], task='translation_lev')
-                generate_main(data_dir)
+                generate_main(data_dir, [
+                    '--task', 'translation_lev',
+                    '--iter-decode-max-iter', '9',
+                    '--iter-decode-eos-penalty', '0',
+                    '--print-step',
+                ])

     def test_mixture_of_experts(self):
         with contextlib.redirect_stdout(StringIO()):
...
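The new and updated tests all decode through `--task translation_lev` with at most ten refinement passes (`--iter-decode-max-iter 9`, plus the forced termination on the final step). A sketch of running just these cases, assuming the standard fairseq layout where this `TestTranslation` class lives in `tests/test_binaries.py` and the repository root is on `sys.path`:

```python
# Hypothetical runner: the module path is an assumption; adjust to your checkout.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames([
    'tests.test_binaries.TestTranslation.test_cmlm_transformer',
    'tests.test_binaries.TestTranslation.test_levenshtein_transformer',
    'tests.test_binaries.TestTranslation.test_nonautoregressive_transformer',
    'tests.test_binaries.TestTranslation.test_insertion_transformer',
])
unittest.TextTestRunner(verbosity=2).run(suite)
```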