Commit 27568a7e authored by Myle Ott, committed by Facebook GitHub Bot

Merge TracingCompliantTransformer and regular Transformer, fix NAT tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899

Differential Revision: D18373060

Pulled By: myleott

fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
parent 2a9b4ec2
......@@ -48,7 +48,7 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
if masks is not None:
outputs, targets = outputs[masks], targets[masks]
if not masks.any():
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
......
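
Aside: the extra `masks is not None` check added above matters because `masks.any()` would raise when no mask is passed at all. A minimal sketch of the guarded pattern, with a hypothetical `masked_nll` helper and dummy shapes (not the fairseq criterion):

import torch
import torch.nn.functional as F

def masked_nll(logits, targets, masks=None):
    # Select the supervised positions first, as in the criterion above.
    if masks is not None:
        logits, targets = logits[masks], targets[masks]
    # Guard on None before calling .any(), otherwise the no-mask case crashes.
    if masks is not None and not masks.any():
        return torch.tensor(0.0)  # nothing left to score
    return F.nll_loss(F.log_softmax(logits, dim=-1), targets)

logits, targets = torch.randn(4, 10), torch.randint(0, 10, (4,))
print(masked_nll(logits, targets))                               # no mask at all
print(masked_nll(logits, targets, masks=torch.zeros(4).bool()))  # all-False mask -> 0.0
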
......@@ -3,11 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import torch
from fairseq import utils
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
'output_tokens',
'output_scores',
'attn',
'step',
'max_step',
])
class IterativeRefinementGenerator(object):
......@@ -88,6 +97,8 @@ class IterativeRefinementGenerator(object):
@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None):
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
if len(models) == 1:
# Keep this path for NAT models that do not yet have ensemble wrappers; delete it once they all do.
......@@ -110,7 +121,7 @@ class IterativeRefinementGenerator(object):
# initialize buffers (very model specific, with length prediction or not)
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()
finalized = [[] for _ in range(bsz)]
......@@ -150,8 +161,10 @@ class IterativeRefinementGenerator(object):
"max_ratio": self.max_ratio,
"decoding_format": self.decoding_format,
}
prev_decoder_out[3] = step
prev_decoder_out[4] = self.max_iter + 1
prev_decoder_out = prev_decoder_out._replace(
step=step,
max_step=self.max_iter + 1,
)
decoder_out = model.forward_decoder(
prev_decoder_out, encoder_out, **decoder_options
......@@ -160,24 +173,26 @@ class IterativeRefinementGenerator(object):
if self.adaptive:
# terminate if there is a loop
terminated, out_tokens, out_scores, out_attn = is_a_loop(
prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
)
decoder_out = decoder_out._replace(
output_tokens=out_tokens,
output_scores=out_scores,
attn=out_attn,
)
decoder_out[0] = out_tokens
decoder_out[1] = out_scores
decoder_out[2] = out_attn
else:
terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()
if step == self.max_iter: # reach last iteration, terminate
terminated.fill_(1)
# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out[0][terminated]
finalized_scores = decoder_out[1][terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out[2] is None else decoder_out[2][terminated]
None if decoder_out.attn is None else decoder_out.attn[terminated]
)
for i in range(finalized_idxs.size(0)):
......@@ -194,10 +209,15 @@ class IterativeRefinementGenerator(object):
break
# for next step
prev_decoder_out = _skip(decoder_out, ~terminated)
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
sent_idxs = _skip(sent_idxs, ~terminated)
not_terminated = ~terminated
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()
return finalized
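
For reference, the generator changes above replace positional indexing into the decoder state (`decoder_out[0]`, `prev_decoder_out[3]`, ...) with the named fields of the new `DecoderOut` namedtuple, and use `_replace` to build updated copies instead of mutating in place. A self-contained sketch of that pattern with dummy tensors (not the real model state):

from collections import namedtuple
import torch

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
])

state = DecoderOut(
    output_tokens=torch.full((2, 5), 3, dtype=torch.long),
    output_scores=torch.zeros(2, 5),
    attn=None,
    step=0,
    max_step=10,
)

# Named access replaces state[0]; _replace returns a new tuple with the given
# fields swapped out and everything else untouched.
next_state = state._replace(step=state.step + 1,
                            output_scores=state.output_scores + 1.0)
print(state.step, next_state.step)  # 0 1
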
......@@ -10,9 +10,9 @@ Ghazvininejad, Marjan, et al.
arXiv preprint arXiv:1904.09324 (2019).
"""
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
from fairseq.utils import new_arange
def _skeptical_unmasking(output_scores, output_masks, p):
......@@ -55,11 +55,11 @@ class CMLMNATransformerModel(NATransformerModel):
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
step = decoder_out["step"]
max_step = decoder_out["max_step"]
step = decoder_out.step
max_step = decoder_out.max_step
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# execute the decoder
output_masks = output_tokens.eq(self.unk)
......@@ -78,7 +78,11 @@ class CMLMNATransformerModel(NATransformerModel):
output_tokens.masked_fill_(skeptical_mask, self.unk)
output_scores.masked_fill_(skeptical_mask, 0.0)
return {"output_tokens": output_tokens, "output_scores": output_scores}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
@register_model_architecture("cmlm_transformer", "cmlm_transformer")
......
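
For context, the CMLM decoder above re-masks low-confidence positions between refinement iterations: `skeptical_mask` is filled with `self.unk` and its scores are zeroed so the next pass predicts those tokens again. A stand-alone illustration of that idea using `torch.topk` — a hypothetical sketch, not fairseq's `_skeptical_unmasking`:

import torch

def remask_lowest(output_tokens, output_scores, p, unk=3):
    # Re-mask the fraction p of positions with the lowest scores (Mask-Predict style).
    num_remask = max(1, int(p * output_tokens.size(1)))
    _, worst = output_scores.topk(num_remask, dim=1, largest=False)
    rows = torch.arange(output_tokens.size(0)).unsqueeze(1)
    skeptical_mask = torch.zeros_like(output_tokens, dtype=torch.bool)
    skeptical_mask[rows, worst] = True
    return (output_tokens.masked_fill(skeptical_mask, unk),
            output_scores.masked_fill(skeptical_mask, 0.0))

tokens, scores = torch.randint(4, 100, (2, 6)), torch.rand(2, 6)
print(remask_lowest(tokens, scores, p=0.5))
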
......@@ -6,7 +6,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
LevenshteinTransformerDecoder,
......@@ -14,6 +14,7 @@ from fairseq.models.levenshtein_transformer import (
)
from fairseq.models.transformer import Linear, TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import new_arange
class NegativeDistanceScore(object):
......@@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi
@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
@staticmethod
def add_args(parser):
......@@ -169,8 +170,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
):
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# TODO: decoding for InsertionTransformer
word_ins_out = self.decoder.forward_word_ins(
output_tokens, encoder_out=encoder_out
......@@ -187,7 +188,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
......@@ -206,7 +211,7 @@ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
self.label_tau = getattr(args, "label_tau", None)
def forward_word_ins(self, prev_output_tokens, encoder_out=None):
features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
features = self.pool_out(
torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
)
......
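
The `forward_word_ins` change above keeps only the features tensor from `extract_features`; the existing `pool_out` projection then builds one representation per insertion slot by concatenating the features of each pair of adjacent positions. A toy sketch of that slot construction with assumed sizes (a plain `nn.Linear` standing in for `self.pool_out`):

import torch
import torch.nn as nn

embed_dim = 8
pool_out = nn.Linear(2 * embed_dim, embed_dim)

features = torch.randn(2, 5, embed_dim)  # B x T x C decoder features
# Slot i sits between tokens i and i+1, so pair up neighbouring positions.
slots = torch.cat([features[:, :-1, :], features[:, 1:, :]], dim=2)  # B x (T-1) x 2C
slot_repr = pool_out(slots)                                          # B x (T-1) x C
print(slot_repr.shape)  # torch.Size([2, 4, 8])
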
......@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
......
......@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
from typing import List, Optional
import torch
from torch import Tensor
......@@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask):
return res
@torch.jit.script
def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
outputs = {}
for s, t in x.items():
outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
return outputs
def skip_tensors(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [skip_tensors(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: skip_tensors(v, mask) for k, v in x.items()}
raise NotImplementedError
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
"""
......@@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
@torch.jit.script
def fill_tensors(x, mask, y, padding_idx: int):
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
return x if x is not None else y
@torch.jit.script
def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
"""
Filling tensor x with y at masked positions (dim=0).
"""
if x is None or x.size()[0] == 0:
return torch.empty([0])
if x is None or x.size()[0] == 0 or y is None:
return x
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
......@@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int):
else:
x[mask] = y
return x
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len, device=out_lengths.device)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len, device=in_tokens.device)[None, :]
.expand_as(in_tokens)
.contiguous()
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
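
The model_utils.py hunk above tightens `fill_tensors` to `Optional[Tensor]` arguments so it remains scriptable, adds the small `coalesce` helper for unwrapping optionals at call sites, and drops the `skip_tensors`/`_apply_*` helpers (the latter now live in levenshtein_transformer.py, as the ensemble imports below show). A minimal sketch of the TorchScript pattern involved — `@torch.jit.script` with `Optional` refinement, mirroring `coalesce`:

from typing import Optional

import torch
from torch import Tensor

@torch.jit.script
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    # Inside the None check TorchScript refines Optional[Tensor] to Tensor,
    # so callers never have to unwrap by hand.
    if x is not None:
        return x
    return y

print(coalesce(None, torch.zeros(2)))           # falls back to y
print(coalesce(torch.ones(2), torch.zeros(2)))  # keeps x
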
......@@ -3,11 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn.functional as F
import math
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words
from fairseq.models.levenshtein_transformer import (
_skip,
_apply_ins_masks,
_apply_ins_words,
_apply_del_words,
)
from fairseq.models.model_utils import fill_tensors as _fill
class BasicEnsembleModel(torch.nn.Module):
......
......@@ -5,7 +5,9 @@
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
Embedding,
......@@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens):
@register_model("nonautoregressive_transformer")
class NATransformerModel(TransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
self.tgt_dict = decoder.dictionary
self.bos = decoder.dictionary.bos()
self.eos = decoder.dictionary.eos()
......@@ -112,9 +114,9 @@ class NATransformerModel(TransformerModel):
return self.encoder(*encoder_inputs)
def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
step = decoder_out["step"]
output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
step = decoder_out.step
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# execute the decoder
output_masks = output_tokens.ne(self.pad)
......@@ -127,12 +129,16 @@ class NATransformerModel(TransformerModel):
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)
def initialize_output_tokens(self, encoder_out, src_tokens):
# length prediction
_, length_tgt = self.decoder.forward_length_prediction(encoder_out)
max_length = length_tgt.max()
max_length = length_tgt.clamp_(min=2).max()
idx_length = utils.new_arange(src_tokens, max_length)
initial_output_tokens = src_tokens.new_zeros(
......@@ -146,13 +152,15 @@ class NATransformerModel(TransformerModel):
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(encoder_out["encoder_out"])
return {
"output_tokens": initial_output_tokens,
"output_scores": initial_output_scores,
"attn": None
}
).type_as(encoder_out.encoder_out)
return DecoderOut(
output_tokens=initial_output_tokens,
output_scores=initial_output_scores,
attn=None,
step=0,
max_step=0,
)
class NATransformerDecoder(TransformerDecoder):
......@@ -220,8 +228,8 @@ class NATransformerDecoder(TransformerDecoder):
"""
# embedding
if embedding_copy:
src_embd = encoder_out["encoder_embedding"]
src_mask = encoder_out["encoder_padding_mask"]
src_embd = encoder_out.encoder_embedding
src_mask = encoder_out.encoder_padding_mask
src_mask = (
~src_mask
if src_mask is not None
......@@ -253,10 +261,8 @@ class NATransformerDecoder(TransformerDecoder):
x, attn = layer(
x,
encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"]
if encoder_out is not None
else None,
encoder_out.encoder_out if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
self_attn_mask=None,
self_attn_padding_mask=decoder_padding_mask,
)
......@@ -311,8 +317,8 @@ class NATransformerDecoder(TransformerDecoder):
return copied_embedding
def forward_length_prediction(self, encoder_out, tgt_tokens=None):
enc_feats = encoder_out["encoder_out"] # T x B x C
src_masks = encoder_out["encoder_padding_mask"] # B x T or None
enc_feats = encoder_out.encoder_out # T x B x C
src_masks = encoder_out.encoder_padding_mask # B x T or None
if self.pred_length_offset:
if src_masks is None:
......
......@@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder):
- a dictionary of additional data, where 'inner_states'
is a list of hidden states.
"""
x, extra = self.extract_features(src_tokens, return_all_hiddens)
x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
......
......@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
import math
import torch
......@@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel):
return decoder_out
EncoderOut = namedtuple('TransformerEncoderOut', [
'encoder_out', # T x B x C
'encoder_padding_mask', # B x T
'encoder_embedding', # B x T x C
'encoder_states', # List[T x B x C]
])
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
......@@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder):
intermediate hidden states (default: False).
Returns:
dict:
namedtuple:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
......@@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder):
if return_all_hiddens:
encoder_states[-1] = x
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
'encoder_embedding': encoder_embedding, # B x T x C
'encoder_states': encoder_states, # List[T x B x C]
}
return EncoderOut(
encoder_out=x, # T x B x C
encoder_padding_mask=encoder_padding_mask, # B x T
encoder_embedding=encoder_embedding, # B x T x C
encoder_states=encoder_states, # List[T x B x C]
)
def reorder_encoder_out(self, encoder_out, new_order):
"""
......@@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder):
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = \
encoder_out['encoder_out'].index_select(1, new_order)
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
if encoder_out.get('encoder_states', None) is not None:
for idx, state in enumerate(encoder_out['encoder_states']):
encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
if encoder_out.encoder_out is not None:
encoder_out = encoder_out._replace(
encoder_out=encoder_out.encoder_out.index_select(1, new_order)
)
if encoder_out.encoder_padding_mask is not None:
encoder_out = encoder_out._replace(
encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
)
if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out
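
`reorder_encoder_out` above now rebuilds the `EncoderOut` namedtuple with `_replace` rather than mutating a dict in place; the reordering itself is an `index_select` over the batch dimension of each field (dim 1 for the `T x B x C` encoder output, dim 0 for the `B x T` mask and `B x T x C` embedding). A toy sketch with dummy shapes:

from collections import namedtuple
import torch

EncoderOut = namedtuple('TransformerEncoderOut', [
    'encoder_out', 'encoder_padding_mask', 'encoder_embedding', 'encoder_states',
])

T, B, C = 7, 4, 16
enc = EncoderOut(
    encoder_out=torch.randn(T, B, C),                          # T x B x C
    encoder_padding_mask=torch.zeros(B, T, dtype=torch.bool),  # B x T
    encoder_embedding=torch.randn(B, T, C),                    # B x T x C
    encoder_states=None,
)

new_order = torch.tensor([2, 0])  # e.g. keep sentences 2 and 0, in that order
enc = enc._replace(
    encoder_out=enc.encoder_out.index_select(1, new_order),
    encoder_padding_mask=enc.encoder_padding_mask.index_select(0, new_order),
    encoder_embedding=enc.encoder_embedding.index_select(0, new_order),
)
print(enc.encoder_out.shape)  # torch.Size([7, 2, 16])
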
def max_positions(self):
......@@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
**extra_args
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
......@@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
**extra_args
)
if not features_only:
x = self.output_layer(x)
......@@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out['encoder_states'][idx]
encoder_state = encoder_out.encoder_states[idx]
else:
encoder_state = encoder_out['encoder_out']
encoder_state = encoder_out.encoder_out
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
......@@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
......
......@@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
def forward(self, encoder_out):
if not (
isinstance(encoder_out, dict)
and 'encoder_out' in encoder_out
and 'encoder_padding_mask' in encoder_out
and encoder_out['encoder_out'].size(2) == self.embed_dim
hasattr(encoder_out, 'encoder_out')
and hasattr(encoder_out, 'encoder_padding_mask')
and encoder_out.encoder_out.size(2) == self.embed_dim
):
raise ValueError('Unexpected format for encoder_out')
# mean pooling over time
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C
encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
......
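
The gating network above switches from dict lookups to attribute access on `EncoderOut` and zeroes padded timesteps before pooling. For completeness, a generic sketch of masked mean pooling over time under assumed `B x T x C` shapes — an illustration of the technique, not necessarily the exact continuation of the fairseq module:

import torch

def masked_mean(features, padding_mask=None):
    # features: B x T x C, padding_mask: B x T with True at padded positions.
    if padding_mask is None:
        return features.mean(dim=1)
    features = features.clone()  # required because of the in-place zeroing
    features[padding_mask] = 0
    lengths = (~padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
    return features.sum(dim=1) / lengths.type_as(features)

feats = torch.randn(3, 5, 8)
mask = torch.zeros(3, 5, dtype=torch.bool)
mask[:, 4] = True  # last timestep is padding
print(masked_mean(feats, mask).shape)  # torch.Size([3, 8])
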
......@@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase):
])
generate_main(data_dir)
def test_cmlm_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'cmlm_transformer', [
'--apply-bert-init',
'--criterion', 'nat_loss',
'--noise', 'full_mask',
'--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_levenshtein_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'levenshtein_transformer', [
'--apply-bert-init', '--early-exit', '6,6,6',
'--criterion', 'nat_loss'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_iterative_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
'--dae-ratio', '0.5', '--train-step', '3'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_insertion_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'insertion_transformer', [
'--apply-bert-init', '--criterion', 'nat_loss', '--noise',
'random_mask'
], task='translation_lev')
generate_main(data_dir)
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()):
......