Commit 27568a7e authored by Myle Ott, committed by Facebook Github Bot

Merge TracingCompliantTransformer and regular Transformer, fix NAT tests

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899

Differential Revision: D18373060

Pulled By: myleott

fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
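Note (illustrative, not part of the commit): the central change in this diff is that non-autoregressive decoder state moves from indexed lists/dicts to the DecoderOut namedtuple defined in fairseq/iterative_refinement_generator.py below, read by field name and updated immutably with _replace. A minimal sketch of that pattern:

import torch
from fairseq.iterative_refinement_generator import DecoderOut  # defined in this diff

# Build an initial decoder state the way initialize_output_tokens now does,
# then advance it with _replace instead of assigning into list slots.
state = DecoderOut(
    output_tokens=torch.tensor([[0, 2]]),  # e.g. <s> </s>
    output_scores=torch.zeros(1, 2),
    attn=None,
    step=0,
    max_step=10,
)
state = state._replace(step=state.step + 1)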
parent 2a9b4ec2
@@ -48,7 +48,7 @@ class LabelSmoothedDualImitationCriterion(FairseqCriterion):
        if masks is not None:
            outputs, targets = outputs[masks], targets[masks]

-        if not masks.any():
+        if masks is not None and not masks.any():
            nll_loss = torch.tensor(0)
            loss = nll_loss
        else:
......
@@ -3,11 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+from collections import namedtuple
+
import torch

from fairseq import utils
-from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
-from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
-from fairseq.models.nonautoregressive_ensembles import EnsembleLevT
+
+DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
+    'output_tokens',
+    'output_scores',
+    'attn',
+    'step',
+    'max_step',
+])


class IterativeRefinementGenerator(object):
@@ -88,6 +97,8 @@ class IterativeRefinementGenerator(object):

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None):
+        from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
+        from fairseq.models.nonautoregressive_ensembles import EnsembleLevT

        if len(models) == 1:
            # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
@@ -110,7 +121,7 @@ class IterativeRefinementGenerator(object):

        # initialize buffers (very model specific, with length prediction or not)
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
-        prev_output_tokens = prev_decoder_out[0].clone()
+        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        finalized = [[] for _ in range(bsz)]
@@ -150,8 +161,10 @@ class IterativeRefinementGenerator(object):
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
-            prev_decoder_out[3] = step
-            prev_decoder_out[4] = self.max_iter + 1
+            prev_decoder_out = prev_decoder_out._replace(
+                step=step,
+                max_step=self.max_iter + 1,
+            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
@@ -160,24 +173,26 @@ class IterativeRefinementGenerator(object):
            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
-                    prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
+                    prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
                )
-                decoder_out[0] = out_tokens
-                decoder_out[1] = out_scores
-                decoder_out[2] = out_attn
+                decoder_out = decoder_out._replace(
+                    output_tokens=out_tokens,
+                    output_scores=out_scores,
+                    attn=out_attn,
+                )
            else:
-                terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
+                terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out[0][terminated]
-            finalized_scores = decoder_out[1][terminated]
+            finalized_tokens = decoder_out.output_tokens[terminated]
+            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
-                None if decoder_out[2] is None else decoder_out[2][terminated]
+                None if decoder_out.attn is None else decoder_out.attn[terminated]
            )

            for i in range(finalized_idxs.size(0)):
@@ -194,10 +209,15 @@ class IterativeRefinementGenerator(object):
                    break

            # for next step
-            prev_decoder_out = _skip(decoder_out, ~terminated)
-            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
-            sent_idxs = _skip(sent_idxs, ~terminated)
+            not_terminated = ~terminated
+            prev_decoder_out = decoder_out._replace(
+                output_tokens=decoder_out.output_tokens[not_terminated],
+                output_scores=decoder_out.output_scores[not_terminated],
+                attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
+            )
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
+            sent_idxs = sent_idxs[not_terminated]

-            prev_output_tokens = prev_decoder_out[0].clone()
+            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        return finalized
@@ -10,9 +10,9 @@ Ghazvininejad, Marjan, et al.
arXiv preprint arXiv:1904.09324 (2019).
"""

-from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
+from fairseq.utils import new_arange


def _skeptical_unmasking(output_scores, output_masks, p):
@@ -55,11 +55,11 @@ class CMLMNATransformerModel(NATransformerModel):

    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
-        step = decoder_out["step"]
-        max_step = decoder_out["max_step"]
+        step = decoder_out.step
+        max_step = decoder_out.max_step

-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

        # execute the decoder
        output_masks = output_tokens.eq(self.unk)
@@ -78,7 +78,11 @@ class CMLMNATransformerModel(NATransformerModel):
        output_tokens.masked_fill_(skeptical_mask, self.unk)
        output_scores.masked_fill_(skeptical_mask, 0.0)

-        return {"output_tokens": output_tokens, "output_scores": output_scores}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


@register_model_architecture("cmlm_transformer", "cmlm_transformer")
......
@@ -6,7 +6,7 @@
import numpy as np
import torch
import torch.nn.functional as F
-from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
    LevenshteinTransformerDecoder,
@@ -14,6 +14,7 @@ from fairseq.models.levenshtein_transformer import (
)
from fairseq.models.transformer import Linear, TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
+from fairseq.utils import new_arange


class NegativeDistanceScore(object):
@@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi

@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
-    def __init__(self, encoder, decoder):
-        super().__init__(encoder, decoder)
+    def __init__(self, args, encoder, decoder):
+        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
@@ -169,8 +170,8 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
        self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
    ):

-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

        # TODO: decoding for InsertionTransformer
        word_ins_out = self.decoder.forward_word_ins(
            output_tokens, encoder_out=encoder_out
@@ -187,7 +188,11 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
-        return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )


class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
@@ -206,7 +211,7 @@ class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
        self.label_tau = getattr(args, "label_tau", None)

    def forward_word_ins(self, prev_output_tokens, encoder_out=None):
-        features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
+        features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
        features = self.pool_out(
            torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
        )
......
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
......
-#!/usr/bin/env python3
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
+# Copyright (c) Facebook, Inc. and its affiliates.
#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-from __future__ import absolute_import, division, print_function, unicode_literals
-from typing import Optional
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

+from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
-from fairseq.models.tracing_compliant_transformer import (
-    TracingTransformerDecoder,
-    TracingTransformerEncoder,
-    TracingTransformerModel,
-    TransformerDecoderLayer,
-)
-from fairseq.models.model_utils import (
-    fill_tensors as _fill,
-    script_skip_tensor,
-    script_skip_tensor_list,
-)
-from fairseq.models.transformer import Embedding
+from fairseq.models.transformer import (
+    Embedding,
+    TransformerDecoder,
+    TransformerEncoder,
+    TransformerModel,
+    TransformerDecoderLayer
+)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
-from torch import Tensor
+from fairseq.utils import new_arange
# -------------- Helper Functions --------------------------------------------------- #
def _skip(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [_skip(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: _skip(v, mask) for k, v in x.items()}
raise NotImplementedError
def _skip_encoder_out(encoder, encoder_out, mask):
if not mask.any():
return encoder_out
else:
return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze())
def _fill(x, mask, y, padding_idx):
"""
Filling tensor x with y at masked positions (dim=0).
"""
if x is None:
return y
assert x.dim() == y.dim() and mask.size(0) == x.size(0)
assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
n_selected = mask.sum()
assert n_selected == y.size(0)
if n_selected == x.size(0):
return y
if x.size(1) < y.size(1):
dims = [x.size(0), y.size(1) - x.size(1)]
if x.dim() == 3:
dims.append(x.size(2))
x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1)
x[mask] = y
elif x.size(1) > y.size(1):
x[mask] = padding_idx
if x.dim() == 2:
x[mask, :y.size(1)] = y
else:
x[mask, :y.size(1), :] = y
else:
x[mask] = y
return x
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): def load_libnat():
try: try:
from fairseq import libnat from fairseq import libnat
except ImportError as e: except ImportError as e:
import sys import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n") sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e raise e
return libnat
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
libnat = load_libnat()
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [ in_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
] ]
out_tokens_list = [ out_tokens_list = [
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) [t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
] ]
full_labels = libnat.suggested_ed2_path( full_labels = libnat.suggested_ed2_path(
...@@ -71,28 +128,27 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): ...@@ -71,28 +128,27 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
] ]
# transform to tensor # transform to tensor
masked_tgt_masks = torch.tensor(masked_tgt_masks, device=out_tokens.device).bool() masked_tgt_masks = torch.tensor(
masked_tgt_masks, device=out_tokens.device
).bool()
mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx)
return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets
def _get_del_targets(in_tokens, out_tokens, padding_idx): def _get_del_targets(in_tokens, out_tokens, padding_idx):
try: libnat = load_libnat()
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
out_seq_len = out_tokens.size(1) out_seq_len = out_tokens.size(1)
in_tokens_list = [ with torch.cuda.device_of(in_tokens):
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) in_tokens_list = [
] [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
out_tokens_list = [ ]
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) out_tokens_list = [
] [t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path( full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx in_tokens_list, out_tokens_list, padding_idx
...@@ -104,26 +160,23 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx): ...@@ -104,26 +160,23 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
] ]
# transform to tensor # transform to tensor
word_del_targets = torch.tensor(word_del_targets) word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets return word_del_targets
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx): def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
try: libnat = load_libnat()
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n")
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
in_tokens_list = [ with torch.cuda.device_of(in_tokens):
[t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) in_tokens_list = [
] [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist())
out_tokens_list = [ ]
[t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) out_tokens_list = [
] [t for t in s if t != padding_idx]
for i, s in enumerate(out_tokens.tolist())
]
full_labels = libnat.suggested_ed2_path( full_labels = libnat.suggested_ed2_path(
in_tokens_list, out_tokens_list, padding_idx in_tokens_list, out_tokens_list, padding_idx
...@@ -144,15 +197,101 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx): ...@@ -144,15 +197,101 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
] ]
# transform to tensor # transform to tensor
mask_ins_targets = torch.tensor(mask_ins_targets) mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device)
word_del_targets = torch.tensor(word_del_targets) word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device)
return word_del_targets, mask_ins_targets return word_del_targets, mask_ins_targets
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
new_arange(out_lengths, out_max_len)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(
in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx
):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
new_arange(in_tokens)
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
# ------------------------------------------------------------------------------------- #
@register_model("levenshtein_transformer") @register_model("levenshtein_transformer")
class LevenshteinTransformerModel(TracingTransformerModel): class LevenshteinTransformerModel(TransformerModel):
def __init__(self, encoder, decoder): def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder) super().__init__(args, encoder, decoder)
self.tgt_dict = decoder.dictionary self.tgt_dict = decoder.dictionary
self.bos = decoder.dictionary.bos() self.bos = decoder.dictionary.bos()
self.eos = decoder.dictionary.eos() self.eos = decoder.dictionary.eos()
...@@ -161,7 +300,7 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -161,7 +300,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
TracingTransformerModel.add_args(parser) TransformerModel.add_args(parser)
parser.add_argument( parser.add_argument(
"--apply-bert-init", "--apply-bert-init",
action="store_true", action="store_true",
...@@ -171,31 +310,27 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -171,31 +310,27 @@ class LevenshteinTransformerModel(TracingTransformerModel):
"--early-exit", "--early-exit",
default="6,6,6", default="6,6,6",
type=str, type=str,
help="number of decoder layers for del_word, ins_mask, ins_word", help="number of decoder layers before word_del, mask_ins, word_ins",
) )
parser.add_argument( parser.add_argument(
"--no-share-discriminator", "--no-share-discriminator",
action="store_true", action="store_true",
help="addtional decoder-layers to learn deletion", help="separate parameters for discriminator",
) )
parser.add_argument( parser.add_argument(
"--no-share-maskpredictor", "--no-share-maskpredictor",
action="store_true", action="store_true",
help="addtional decoder-layers to learn predicting masks", help="separate parameters for mask-predictor",
) )
parser.add_argument( parser.add_argument(
"--sampling-for-deletion", "--share-discriminator-maskpredictor",
action="store_true", action="store_true",
help="instead of argmax, use sampling to predict the tokens", help="share the parameters for both mask-predictor and discriminator",
) )
# Added for compatibility
parser.add_argument( parser.add_argument(
"--decoder-out-embed-dim", "--sampling-for-deletion",
default=None, action='store_true',
type=int, help='instead of argmax, use sampling to predict the tokens'
metavar="N",
help="decoder output embedding dimension (bottleneck layer before"
"output layer if specified.)",
) )
@classmethod @classmethod
...@@ -207,7 +342,7 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -207,7 +342,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
@classmethod @classmethod
def build_encoder(cls, args, src_dict, embed_tokens): def build_encoder(cls, args, src_dict, embed_tokens):
encoder = TracingTransformerEncoder(args, src_dict, embed_tokens) encoder = TransformerEncoder(args, src_dict, embed_tokens)
if getattr(args, "apply_bert_init", False): if getattr(args, "apply_bert_init", False):
encoder.apply(init_bert_params) encoder.apply(init_bert_params)
return encoder return encoder
...@@ -238,8 +373,8 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -238,8 +373,8 @@ class LevenshteinTransformerModel(TracingTransformerModel):
# make online prediction # make online prediction
if self.decoder.sampling_for_deletion: if self.decoder.sampling_for_deletion:
word_predictions = torch.multinomial( word_predictions = torch.multinomial(
F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1 F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view(
).view(word_ins_out.size(0), -1) word_ins_out.size(0), -1)
else: else:
word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1] word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1]
...@@ -249,7 +384,10 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -249,7 +384,10 @@ class LevenshteinTransformerModel(TracingTransformerModel):
# generate training labels for deletion # generate training labels for deletion
word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad) word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad)
word_del_out, _ = self.decoder.forward_word_del(word_predictions, encoder_out) word_del_out, _ = self.decoder.forward_word_del(
word_predictions, encoder_out)
word_del_masks = word_predictions.ne(self.pad)
return { return {
"mask_ins_out": mask_ins_out, "mask_ins_out": mask_ins_out,
"mask_ins_tgt": mask_ins_targets, "mask_ins_tgt": mask_ins_targets,
...@@ -259,7 +397,7 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -259,7 +397,7 @@ class LevenshteinTransformerModel(TracingTransformerModel):
"word_ins_mask": masked_tgt_masks, "word_ins_mask": masked_tgt_masks,
"word_del_out": word_del_out, "word_del_out": word_del_out,
"word_del_tgt": word_del_targets, "word_del_tgt": word_del_targets,
"word_del_mask": word_predictions.ne(self.pad), "word_del_mask": word_del_masks,
} }
def forward_encoder(self, encoder_inputs): def forward_encoder(self, encoder_inputs):
...@@ -269,248 +407,123 @@ class LevenshteinTransformerModel(TracingTransformerModel): ...@@ -269,248 +407,123 @@ class LevenshteinTransformerModel(TracingTransformerModel):
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
): ):
output_tokens = decoder_out[0] output_tokens = decoder_out.output_tokens
output_scores = decoder_out[1] output_scores = decoder_out.output_scores
attn = decoder_out[2] attn = decoder_out.attn
if max_ratio is not None and encoder_out[1] is not None:
max_lengths = ((~encoder_out[1]).sum(1) * max_ratio).clamp(min=10)
bsz = output_tokens.size(0)
if max_ratio is None:
max_lens = torch.zeros_like(output_tokens).fill_(255)
else: else:
max_lengths = torch.zeros(output_tokens.size(0)).fill_(255) if encoder_out.encoder_padding_mask is None:
max_src_len = encoder_out.encoder_out.size(1)
@torch.jit.script src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len)
def del_word( else:
output_tokens, src_lens = (~encoder_out.encoder_padding_mask).sum(1)
output_scores, max_lens = (src_lens * max_ratio).clamp(min=10).long()
attn: Tensor,
word_del_attn: Optional[Tensor], # delete words
word_del_out, # do not delete tokens if it is <s> </s>
can_del_word,
pad_idx: int,
bos_idx: int,
eos_idx: int,
):
# delete words
# do not delete tokens if it is <s> </s>
if can_del_word.sum() != 0: # we cannot delete, skip
word_del_score = F.log_softmax(word_del_out, 2)
word_del_pred = torch.jit.Attribute(word_del_score.max(-1)[1], bool)
in_tokens = output_tokens[can_del_word]
in_scores = output_scores[can_del_word]
# apply deletion to a tensor
in_masks = in_tokens.ne(pad_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len)[None, :]
.expand_as(in_tokens)
.contiguous()
.masked_fill(word_del_pred, max_len)
.sort(1)[1]
)
_tokens = in_tokens.masked_fill(word_del_pred, pad_idx).gather(
1, reordering
)
_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
if word_del_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(word_del_attn)
_reordering = reordering[:, :, None].expand_as(word_del_attn)
_attn = word_del_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
attn = _fill(attn, can_del_word, _attn, 0)
output_tokens = _fill(output_tokens, can_del_word, _tokens, pad_idx)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
return output_tokens, output_scores, attn
@torch.jit.script
def ins_placeholders(
output_tokens,
output_scores,
mask_ins_out,
can_ins_mask,
pad_idx: int,
unk_idx: int,
eos_idx: int,
max_ratio: float,
max_lengths,
):
# insert placeholders
if can_ins_mask.sum() != 0:
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] -= eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
if max_ratio is not None and encoder_out[1] is not None:
mask_ins_pred = torch.min(
mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
)
in_tokens = output_tokens[can_ins_mask]
in_scores = output_scores[can_ins_mask]
in_masks = in_tokens.ne(pad_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len)[None, :].long() < out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
torch.zeros(in_tokens.size()[0], out_max_len)
.fill_(pad_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = torch.zeros_like(out_tokens).to(in_scores)
out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
out_scores.scatter_(1, reordering, in_scores[:, 1:])
else:
out_scores = None
output_tokens = _fill(output_tokens, can_ins_mask, out_tokens, pad_idx)
output_scores = _fill(output_scores, can_ins_mask, out_scores, 0)
return output_tokens, output_scores
@torch.jit.script
def ins_words(
output_tokens,
output_scores,
attn: Tensor,
word_ins_attn,
word_ins_out,
can_ins_word,
pad_idx: int,
unk_idx: int,
):
# insert words
if can_ins_word.sum() != 0:
word_ins_scores = F.log_softmax(word_ins_out, 2)
word_ins_pred = word_ins_scores.max(-1)[1]
in_tokens = output_tokens[can_ins_word]
in_scores = output_scores[can_ins_word]
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(
word_ins_masks, word_ins_pred[word_ins_masks].float()
)
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
attn = _fill(attn, can_ins_word, word_ins_attn, 0)
return output_tokens, output_scores, attn
can_del_word = output_tokens.ne(self.pad).sum(1) > 2 can_del_word = output_tokens.ne(self.pad).sum(1) > 2
word_del_out, word_del_attn = self.decoder.forward_word_del( if can_del_word.sum() != 0: # we cannot delete, skip
script_skip_tensor(output_tokens, can_del_word), word_del_out, word_del_attn = self.decoder.forward_word_del(
script_skip_tensor_list(list(encoder_out), can_del_word), _skip(output_tokens, can_del_word),
) _skip_encoder_out(self.encoder, encoder_out, can_del_word)
)
output_tokens, output_scores, attn = del_word( word_del_score = F.log_softmax(word_del_out, 2)
output_tokens, word_del_pred = word_del_score.max(-1)[1].bool()
output_scores,
attn, _tokens, _scores, _attn = _apply_del_words(
word_del_attn, output_tokens[can_del_word],
word_del_out, output_scores[can_del_word],
can_del_word, word_del_attn,
self.pad, word_del_pred,
self.bos, self.pad,
self.eos, self.bos,
) self.eos,
)
output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
output_scores = _fill(output_scores, can_del_word, _scores, 0)
attn = _fill(attn, can_del_word, _attn, 0.)
# insert placeholders
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
if can_ins_mask.sum() != 0:
mask_ins_out, _ = self.decoder.forward_mask_ins(
_skip(output_tokens, can_ins_mask),
_skip_encoder_out(self.encoder, encoder_out, can_ins_mask)
)
mask_ins_score = F.log_softmax(mask_ins_out, 2)
if eos_penalty > 0.0:
mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
mask_ins_pred = mask_ins_score.max(-1)[1]
mask_ins_pred = torch.min(
mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
)
can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lengths _tokens, _scores = _apply_ins_masks(
mask_ins_out, _ = self.decoder.forward_mask_ins( output_tokens[can_ins_mask],
script_skip_tensor(output_tokens, can_ins_mask), output_scores[can_ins_mask],
script_skip_tensor_list(encoder_out, can_ins_mask), mask_ins_pred,
) self.pad,
output_tokens, output_scores = ins_placeholders( self.unk,
output_tokens, self.eos,
output_scores, )
mask_ins_out, output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
can_ins_mask, output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
self.pad,
self.unk,
self.eos,
max_ratio,
max_lengths,
)
# insert words
can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
word_ins_out, word_ins_attn = self.decoder.forward_word_ins( if can_ins_word.sum() != 0:
script_skip_tensor(output_tokens, can_ins_word), word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
script_skip_tensor_list(encoder_out, can_ins_word), _skip(output_tokens, can_ins_word),
) _skip_encoder_out(self.encoder, encoder_out, can_ins_word)
)
word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
word_ins_pred = word_ins_score.max(-1)[1]
_tokens, _scores = _apply_ins_words(
output_tokens[can_ins_word],
output_scores[can_ins_word],
word_ins_pred,
word_ins_score,
self.unk,
)
output_tokens, output_scores, attn = ins_words( output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
output_tokens, output_scores = _fill(output_scores, can_ins_word, _scores, 0)
output_scores, attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
attn,
word_ins_attn,
word_ins_out,
can_ins_word,
self.pad,
self.unk,
)
# delete some unnecessary paddings # delete some unnecessary paddings
cut_off = output_tokens.ne(self.pad).sum(1).max() cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
@torch.jit.script output_scores = output_scores[:, :cut_off]
def slice_wrap(x, l): attn = None if attn is None else attn[:, :cut_off, :]
return x[:, :l]
return decoder_out._replace(
@torch.jit.script output_tokens=output_tokens,
def slice_wrap_attn(x, l): output_scores=output_scores,
return x if x.size()[0] == 0 else x[:, :l, :] attn=attn,
output_tokens = slice_wrap(output_tokens, cut_off)
output_scores = slice_wrap(output_scores, cut_off)
attn = slice_wrap(attn, cut_off)
return [output_tokens, output_scores, attn, 0, 0]
def initialize_output_tokens(self, encoder_out, src_tokens):
initial_output_tokens = torch.cat(
[
torch.zeros(src_tokens.size(0), 1).fill_(self.bos),
torch.zeros(src_tokens.size(0), 1).fill_(self.eos),
],
1,
) )
initial_output_scores = torch.zeros_like(initial_output_tokens).to( def initialize_output_tokens(self, encoder_out, src_tokens):
encoder_out[0] initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
initial_output_tokens[:, 0] = self.bos
initial_output_tokens[:, 1] = self.eos
initial_output_scores = initial_output_tokens.new_zeros(
*initial_output_tokens.size()
).type_as(encoder_out.encoder_out)
return DecoderOut(
output_tokens=initial_output_tokens,
output_scores=initial_output_scores,
attn=None,
step=0,
max_step=0,
) )
initial_attn = torch.empty([0])
if getattr(self.decoder.layers[-1], "need_attn", True):
initial_attn = torch.zeros([src_tokens.size(0), 2, src_tokens.size(1)]).to(
initial_output_tokens
)
return [initial_output_tokens, initial_output_scores, initial_attn, 0, 0]
class LevenshteinTransformerDecoder(TransformerDecoder):
class LevenshteinTransformerDecoder(TracingTransformerDecoder):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__( super().__init__(
args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
...@@ -524,38 +537,32 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder): ...@@ -524,38 +537,32 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
self.embed_word_del = Embedding(2, self.output_embed_dim, None) self.embed_word_del = Embedding(2, self.output_embed_dim, None)
# del_word, ins_mask, ins_word # del_word, ins_mask, ins_word
self.early_exit = [int(i) for i in args.early_exit.split(",")] self.early_exit = [int(i) for i in args.early_exit.split(',')]
assert len(self.early_exit) == 3 assert len(self.early_exit) == 3
# copy layers for mask-predict/deletion # copy layers for mask-predict/deletion
self.layers_msk = None self.layers_msk = None
if getattr(args, "no_share_maskpredictor", False): if getattr(args, "no_share_maskpredictor", False):
self.layers_msk = nn.ModuleList( self.layers_msk = nn.ModuleList([
[ TransformerDecoderLayer(args, no_encoder_attn)
TransformerDecoderLayer(args, no_encoder_attn) for _ in range(self.early_exit[1])
for _ in range(self.early_exit[1]) ])
]
)
self.layers_del = None self.layers_del = None
if getattr(args, "no_share_discriminator", False): if getattr(args, "no_share_discriminator", False):
self.layers_del = nn.ModuleList( self.layers_del = nn.ModuleList([
[ TransformerDecoderLayer(args, no_encoder_attn)
TransformerDecoderLayer(args, no_encoder_attn) for _ in range(self.early_exit[0])
for _ in range(self.early_exit[0]) ])
]
) if getattr(args, "share_discriminator_maskpredictor", False):
assert getattr(args, "no_share_discriminator", False), "must set saperate discriminator"
self.layers_msk = self.layers_del
def extract_features( def extract_features(
self, self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
prev_output_tokens,
encoder_out=None,
early_exit=None,
layers=None,
**unused
): ):
""" """
Similar to *forward* but only return features. Similar to *forward* but only return features.
Inputs: Inputs:
prev_output_tokens: Tensor(B, T) prev_output_tokens: Tensor(B, T)
encoder_out: a dictionary of hidden states and masks encoder_out: a dictionary of hidden states and masks
...@@ -574,7 +581,7 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder): ...@@ -574,7 +581,7 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
) )
# embed tokens and positions # embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens.long()) x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None: if self.project_in_dim is not None:
x = self.project_in_dim(x) x = self.project_in_dim(x)
...@@ -591,11 +598,11 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder): ...@@ -591,11 +598,11 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) decoder_padding_mask = prev_output_tokens.eq(self.padding_idx)
layers = self.layers if layers is None else layers layers = self.layers if layers is None else layers
early_exit = len(layers) if early_exit is None else early_exit early_exit = len(layers) if early_exit is None else early_exit
for _, layer in enumerate(layers[:early_exit]): for _, layer in enumerate(layers[: early_exit]):
x, attn = layer( x, attn = layer(
x, x,
encoder_out[0] if encoder_out is not None else None, encoder_out.encoder_out if encoder_out is not None else None,
encoder_out[1] if encoder_out is not None else None, encoder_out.encoder_padding_mask if encoder_out is not None else None,
self_attn_mask=None, self_attn_mask=None,
self_attn_padding_mask=decoder_padding_mask, self_attn_padding_mask=decoder_padding_mask,
) )
...@@ -610,38 +617,26 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder): ...@@ -610,38 +617,26 @@ class LevenshteinTransformerDecoder(TracingTransformerDecoder):
if self.project_out_dim is not None: if self.project_out_dim is not None:
x = self.project_out_dim(x) x = self.project_out_dim(x)
return x, attn, inner_states return x, {"attn": attn, "inner_states": inner_states}
def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused): def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features( features, extra = self.extract_features(
prev_output_tokens, prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused
encoder_out=encoder_out,
early_exit=self.early_exit[1],
layers=self.layers_msk,
**unused
) )
features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
return F.linear(features_cat, self.embed_mask_ins.weight), attn return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn']
def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused): def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features( features, extra = self.extract_features(
prev_output_tokens, prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused
encoder_out=encoder_out,
early_exit=self.early_exit[2],
layers=self.layers,
**unused
) )
return self.output_layer(features), attn return self.output_layer(features), extra['attn']
def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused): def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused):
features, attn, _ = self.extract_features( features, extra = self.extract_features(
prev_output_tokens, prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused
encoder_out=encoder_out,
early_exit=self.early_exit[0],
layers=self.layers_del,
**unused
) )
return F.linear(features, self.embed_word_del.weight), attn return F.linear(features, self.embed_word_del.weight), extra['attn']
@register_model_architecture("levenshtein_transformer", "levenshtein_transformer") @register_model_architecture("levenshtein_transformer", "levenshtein_transformer")
...@@ -671,7 +666,7 @@ def base_architecture(args): ...@@ -671,7 +666,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr( args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False args, "share_decoder_input_output_embed", False
) )
args.share_all_embeddings = getattr(args, "share_all_embeddings", True) args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr( args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False args, "no_token_positional_embeddings", False
) )
...@@ -686,6 +681,8 @@ def base_architecture(args): ...@@ -686,6 +681,8 @@ def base_architecture(args):
args.early_exit = getattr(args, "early_exit", "6,6,6") args.early_exit = getattr(args, "early_exit", "6,6,6")
args.no_share_discriminator = getattr(args, "no_share_discriminator", False) args.no_share_discriminator = getattr(args, "no_share_discriminator", False)
args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False) args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False)
args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False)
args.no_share_last_layer = getattr(args, "no_share_last_layer", False)
@register_model_architecture( @register_model_architecture(
......
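A brief illustrative note on the decoder-sharing options added above (flag names taken from the diff; the values below are hypothetical, not from the commit): --share-discriminator-maskpredictor only makes sense together with --no-share-discriminator, and the decoder then aliases its mask-prediction layers to the discriminator layers.

from argparse import Namespace

# Sketch of the new flag combination for the Levenshtein decoder:
args = Namespace(
    early_exit="6,6,6",                      # decoder layers used for word_del / mask_ins / word_ins
    no_share_maskpredictor=False,
    no_share_discriminator=True,             # required by the assert in LevenshteinTransformerDecoder
    share_discriminator_maskpredictor=True,  # then layers_msk is aliased to layers_del
)
assert not args.share_discriminator_maskpredictor or args.no_share_discriminator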
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Dict, List
+from typing import List, Optional

import torch
from torch import Tensor
@@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask):
    return res
@torch.jit.script
def script_skip_tensor_dict(x: Dict[str, Tensor], mask):
outputs = {}
for s, t in x.items():
outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask]
return outputs
def skip_tensors(x, mask):
"""
Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors.
"""
if isinstance(x, int):
return x
if x is None:
return None
if isinstance(x, torch.Tensor):
if x.size(0) == mask.size(0):
return x[mask]
elif x.size(1) == mask.size(0):
return x[:, mask]
if isinstance(x, list):
return [skip_tensors(x_i, mask) for x_i in x]
if isinstance(x, dict):
return {k: skip_tensors(v, mask) for k, v in x.items()}
raise NotImplementedError
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
    """
@@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):

@torch.jit.script
-def fill_tensors(x, mask, y, padding_idx: int):
+def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
+    return x if x is not None else y
+
+
+@torch.jit.script
+def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]:
    """
    Filling tensor x with y at masked positions (dim=0).
    """
-    if x is None or x.size()[0] == 0:
-        return torch.empty([0])
+    if x is None or x.size()[0] == 0 or y is None:
+        return x
    assert x.dim() == y.dim() and mask.size(0) == x.size(0)
    assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))
@@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int):
    else:
        x[mask] = y
    return x
def _apply_ins_masks(
in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx
):
in_masks = in_tokens.ne(padding_idx)
in_lengths = in_masks.sum(1)
# HACK: hacky way to shift all the paddings to eos first.
in_tokens.masked_fill_(~in_masks, eos_idx)
mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len, device=out_lengths.device)[None, :]
< out_lengths[:, None]
)
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = (
in_tokens.new_zeros(in_tokens.size(0), out_max_len)
.fill_(padding_idx)
.masked_fill_(out_masks, unk_idx)
)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
out_scores = None
if in_scores is not None:
in_scores.masked_fill_(~in_masks, 0)
out_scores = in_scores.new_zeros(*out_tokens.size())
out_scores[:, 0] = in_scores[:, 0]
out_scores.scatter_(1, reordering, in_scores[:, 1:])
return out_tokens, out_scores
def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx):
word_ins_masks = in_tokens.eq(unk_idx)
out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks])
if in_scores is not None:
out_scores = in_scores.masked_scatter(
word_ins_masks, word_ins_scores[word_ins_masks]
)
else:
out_scores = None
return out_tokens, out_scores
def _apply_del_words(
in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx
):
# apply deletion to a tensor
in_masks = in_tokens.ne(padding_idx)
bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
max_len = in_tokens.size(1)
word_del_pred.masked_fill_(~in_masks, 1)
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len, device=in_tokens.device)[None, :]
.expand_as(in_tokens)
.contiguous()
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering)
out_scores = None
if in_scores is not None:
out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
out_attn = None
if in_attn is not None:
_mask = word_del_pred[:, :, None].expand_as(in_attn)
_reordering = reordering[:, :, None].expand_as(in_attn)
out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering)
return out_tokens, out_scores, out_attn
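For readers of the model_utils.py changes above, a small usage sketch of fill_tensors (illustrative values only, not from the commit): rows of y are written back into x where mask is True, and x is right-padded with padding_idx when y is wider.

import torch
from fairseq.models.model_utils import fill_tensors  # import path as used elsewhere in this diff

pad = 1
x = torch.full((3, 4), 2, dtype=torch.long)
y = torch.full((2, 6), 7, dtype=torch.long)
mask = torch.tensor([True, False, True])
out = fill_tensors(x, mask, y, pad)
# out has shape (3, 6): rows 0 and 2 are taken from y, row 1 keeps x's values padded with pad.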
@@ -3,11 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+import math
+
import torch
import torch.nn.functional as F
-import math

-from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
-from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words
+from fairseq.models.levenshtein_transformer import (
+    _skip,
+    _apply_ins_masks,
+    _apply_ins_words,
+    _apply_del_words,
+)
+from fairseq.models.model_utils import fill_tensors as _fill


class BasicEnsembleModel(torch.nn.Module):
......
@@ -5,7 +5,9 @@
import torch
import torch.nn.functional as F

from fairseq import utils
+from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
    Embedding,
@@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens):

@register_model("nonautoregressive_transformer")
class NATransformerModel(TransformerModel):
-    def __init__(self, encoder, decoder):
-        super().__init__(encoder, decoder)
+    def __init__(self, args, encoder, decoder):
+        super().__init__(args, encoder, decoder)
        self.tgt_dict = decoder.dictionary
        self.bos = decoder.dictionary.bos()
        self.eos = decoder.dictionary.eos()
@@ -112,9 +114,9 @@ class NATransformerModel(TransformerModel):
        return self.encoder(*encoder_inputs)

    def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):
-        step = decoder_out["step"]
-        output_tokens = decoder_out["output_tokens"]
-        output_scores = decoder_out["output_scores"]
+        step = decoder_out.step
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores

        # execute the decoder
        output_masks = output_tokens.ne(self.pad)
@@ -127,12 +129,16 @@ class NATransformerModel(TransformerModel):
        output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
        output_scores.masked_scatter_(output_masks, _scores[output_masks])

-        return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=None,
+        )

    def initialize_output_tokens(self, encoder_out, src_tokens):
        # length prediction
        _, length_tgt = self.decoder.forward_length_prediction(encoder_out)
-        max_length = length_tgt.max()
+        max_length = length_tgt.clamp_(min=2).max()
        idx_length = utils.new_arange(src_tokens, max_length)

        initial_output_tokens = src_tokens.new_zeros(
@@ -146,13 +152,15 @@ class NATransformerModel(TransformerModel):
        initial_output_scores = initial_output_tokens.new_zeros(
            *initial_output_tokens.size()
-        ).type_as(encoder_out["encoder_out"])
+        ).type_as(encoder_out.encoder_out)

-        return {
-            "output_tokens": initial_output_tokens,
-            "output_scores": initial_output_scores,
-            "attn": None
-        }
+        return DecoderOut(
+            output_tokens=initial_output_tokens,
+            output_scores=initial_output_scores,
+            attn=None,
+            step=0,
+            max_step=0,
+        )


class NATransformerDecoder(TransformerDecoder):
@@ -220,8 +228,8 @@ class NATransformerDecoder(TransformerDecoder):
        """
        # embedding
        if embedding_copy:
-            src_embd = encoder_out["encoder_embedding"]
-            src_mask = encoder_out["encoder_padding_mask"]
+            src_embd = encoder_out.encoder_embedding
+            src_mask = encoder_out.encoder_padding_mask
            src_mask = (
                ~src_mask
                if src_mask is not None
@@ -253,10 +261,8 @@ class NATransformerDecoder(TransformerDecoder):
            x, attn = layer(
                x,
-                encoder_out["encoder_out"] if encoder_out is not None else None,
-                encoder_out["encoder_padding_mask"]
-                if encoder_out is not None
-                else None,
+                encoder_out.encoder_out if encoder_out is not None else None,
+                encoder_out.encoder_padding_mask if encoder_out is not None else None,
                self_attn_mask=None,
                self_attn_padding_mask=decoder_padding_mask,
            )
@@ -311,8 +317,8 @@ class NATransformerDecoder(TransformerDecoder):
        return copied_embedding

    def forward_length_prediction(self, encoder_out, tgt_tokens=None):
-        enc_feats = encoder_out["encoder_out"]  # T x B x C
-        src_masks = encoder_out["encoder_padding_mask"]  # B x T or None
+        enc_feats = encoder_out.encoder_out  # T x B x C
+        src_masks = encoder_out.encoder_padding_mask  # B x T or None
        if self.pred_length_offset:
            if src_masks is None:
......
...@@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder): ...@@ -348,7 +348,7 @@ class RobertaEncoder(FairseqDecoder):
- a dictionary of additional data, where 'inner_states' - a dictionary of additional data, where 'inner_states'
is a list of hidden states. is a list of hidden states.
""" """
x, extra = self.extract_features(src_tokens, return_all_hiddens) x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens)
if not features_only: if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens) x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra return x, extra
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, Linear, base_architecture
from fairseq.modules import (
AdaptiveSoftmax,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
TransformerEncoderLayer,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('tracing_transformer')
class TracingTransformerModel(FairseqEncoderDecoderModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani et al., 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
@classmethod
def hub_models(cls):
# fmt: off
return {
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
}
# fmt: on
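For reference, the checkpoints listed above are the ones fairseq exposes through torch.hub. A hedged usage sketch follows; the exact hub entry names and keyword arguments are taken from the fairseq README for these WMT19 models and may differ for this tracing variant.

import torch

# downloads a large checkpoint on first use; assumes the standard fairseq hub API
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
                       tokenizer='moses', bpe='fastbpe')
print(en2de.translate('Hello world!'))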
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.supports_align_args = True
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# fmt: on
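These add_args hooks are ordinary argparse plumbing; a small stand-alone sketch of the pattern (a hypothetical subset of the flags above, not the fairseq CLI itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--encoder-layers', type=int, metavar='N', default=6,
                    help='num encoder layers')
parser.add_argument('--share-all-embeddings', action='store_true',
                    help='share encoder, decoder and output embeddings')
args = parser.parse_args(['--encoder-layers', '12', '--share-all-embeddings'])
print(args.encoder_layers, args.share_all_embeddings)   # 12 True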
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(
src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = build_embedding(
tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return cls(encoder, decoder)
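What the --share-all-embeddings branch above boils down to: both sides reuse one nn.Embedding instance, so there is a single weight matrix (which is why the dictionaries and embedding dims must match). A toy sketch with made-up sizes:

import torch.nn as nn

VOCAB, DIM, PAD = 100, 16, 1                       # toy sizes, not real model config

shared = nn.Embedding(VOCAB, DIM, padding_idx=PAD)
encoder_embed_tokens = shared
decoder_embed_tokens = shared                      # same object, not a copy

assert encoder_embed_tokens.weight is decoder_embed_tokens.weight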
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TracingTransformerEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TracingTransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, 'no_cross_attention', False),
)
class TracingTransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def __init__(self, args, dictionary, embed_tokens):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
tuple:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) token embeddings of
shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
if self.layer_wise_attention:
return_all_hiddens = True
x, encoder_embedding = self.forward_embedding(src_tokens)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
encoder_states = [] if return_all_hiddens else None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
if return_all_hiddens:
encoder_states[-1] = x
if encoder_states is not None:
return x, encoder_padding_mask, encoder_embedding, encoder_states
else:
return x, encoder_padding_mask, encoder_embedding
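The encoder above returns a plain tuple of tensors rather than a dict precisely so it can go through torch.jit.trace, which (in the PyTorch versions this code targeted) only accepts tensors and nested tuples of tensors as outputs. A toy illustration of that constraint, using a dummy module rather than the real encoder:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def forward(self, tokens):
        mask = tokens.eq(0)                        # pretend 0 is the padding index
        return tokens.float().mean(dim=-1), mask   # tuple of tensors: traceable

toy = ToyEncoder()
tokens = torch.randint(0, 10, (2, 7))
traced = torch.jit.trace(toy, tokens)              # a dict return here would not trace
print(traced(tokens)[0].shape)                     # torch.Size([2])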
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
# 0: encoder_out
# 1: encoder_padding_mask
# 2: encoder_states
if encoder_out[0] is not None:
encoder_out[0] = \
encoder_out[0].index_select(1, new_order)
if encoder_out[1] is not None:
encoder_out[1] = \
encoder_out[1].index_select(0, new_order)
if len(encoder_out) == 3 and encoder_out[2] is not None:
for idx, state in enumerate(encoder_out[2]):
encoder_out[2][idx] = state.index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]
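A worked example of the mask buffered_future_mask builds (utils.fill_with_neg_inf followed by triu, approximated here with torch.full): -inf above the diagonal, zero elsewhere, so adding it to attention scores removes all probability mass from future positions.

import torch

dim = 4
future_mask = torch.triu(torch.full((dim, dim), float('-inf')), 1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
scores = torch.zeros(dim, dim) + future_mask
print(scores.softmax(dim=-1)[0])                   # tensor([1., 0., 0., 0.])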
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i))
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class TracingTransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(dictionary)
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
self.embed_positions = PositionalEmbedding(
args.max_target_positions, embed_dim, self.padding_idx,
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
self.adaptive_softmax = None
self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
features_only=False,
**extra_args,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
full_context_alignment=False,
alignment_layer=None,
alignment_heads=None,
**unused,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = len(self.layers) - 1
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None
# decoder layers
attn = None
inner_states = [x]
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out[3][idx]
else:
encoder_state = encoder_out[0]
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state,
encoder_out[1] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {'attn': attn, 'inner_states': inner_states}
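A sketch (toy shapes only) of the alignment-head averaging performed above, assuming the head-first layout this code slices with attn[:alignment_heads]: the selected heads are kept and averaged into a single (batch, tgt_len, src_len) alignment matrix.

import torch

layer_attn = torch.softmax(torch.randn(8, 2, 5, 7), dim=-1)   # heads x B x tgt x src

alignment_heads = 4
attn = layer_attn.float()
attn = attn[:alignment_heads].mean(dim=0)                      # average selected heads
print(attn.shape)                                              # torch.Size([2, 5, 7])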
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
return F.linear(features, self.embed_tokens.weight)
else:
return F.linear(features, self.embed_out)
else:
return features
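The output_layer above is just a linear projection back to the vocabulary; with --share-decoder-input-output-embed the projection matrix is the input embedding itself. A toy sketch:

import torch
import torch.nn as nn
import torch.nn.functional as F

embed = nn.Embedding(100, 16)                      # toy vocab of 100, dim 16
features = torch.randn(2, 5, 16)                   # B x tgt_len x C decoder features

logits = F.linear(features, embed.weight)          # tied input/output embedding
print(logits.shape)                                # torch.Size([2, 5, 100])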
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if (
not hasattr(self, '_future_mask')
or self._future_mask is None
or self._future_mask.device != tensor.device
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
'0': 'self_attn_layer_norm',
'1': 'encoder_attn_layer_norm',
'2': 'final_layer_norm'
}
for old, new in layer_norm_map.items():
for m in ('weight', 'bias'):
k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
if k in state_dict:
state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
del state_dict[k]
version_key = '{}.version'.format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections import namedtuple
import math import math
import torch import torch
...@@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel): ...@@ -279,6 +280,14 @@ class TransformerAlignModel(TransformerModel):
return decoder_out return decoder_out
EncoderOut = namedtuple('TransformerEncoderOut', [
'encoder_out', # T x B x C
'encoder_padding_mask', # B x T
'encoder_embedding', # B x T x C
'encoder_states', # List[T x B x C]
])
class TransformerEncoder(FairseqEncoder): class TransformerEncoder(FairseqEncoder):
""" """
Transformer encoder consisting of *args.encoder_layers* layers. Each layer Transformer encoder consisting of *args.encoder_layers* layers. Each layer
...@@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder): ...@@ -348,11 +357,13 @@ class TransformerEncoder(FairseqEncoder):
intermediate hidden states (default: False). intermediate hidden states (default: False).
Returns: Returns:
dict: namedtuple:
- **encoder_out** (Tensor): the last encoder layer's output of - **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)` shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of - **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)` padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate - **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`. hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True. Only populated if *return_all_hiddens* is True.
...@@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder): ...@@ -386,12 +397,12 @@ class TransformerEncoder(FairseqEncoder):
if return_all_hiddens: if return_all_hiddens:
encoder_states[-1] = x encoder_states[-1] = x
return { return EncoderOut(
'encoder_out': x, # T x B x C encoder_out=x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T encoder_padding_mask=encoder_padding_mask, # B x T
'encoder_embedding': encoder_embedding, # B x T x C encoder_embedding=encoder_embedding, # B x T x C
'encoder_states': encoder_states, # List[T x B x C] encoder_states=encoder_states, # List[T x B x C]
} )
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
""" """
...@@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder): ...@@ -404,15 +415,21 @@ class TransformerEncoder(FairseqEncoder):
Returns: Returns:
*encoder_out* rearranged according to *new_order* *encoder_out* rearranged according to *new_order*
""" """
if encoder_out['encoder_out'] is not None: if encoder_out.encoder_out is not None:
encoder_out['encoder_out'] = \ encoder_out = encoder_out._replace(
encoder_out['encoder_out'].index_select(1, new_order) encoder_out=encoder_out.encoder_out.index_select(1, new_order)
if encoder_out['encoder_padding_mask'] is not None: )
encoder_out['encoder_padding_mask'] = \ if encoder_out.encoder_padding_mask is not None:
encoder_out['encoder_padding_mask'].index_select(0, new_order) encoder_out = encoder_out._replace(
if encoder_out.get('encoder_states', None) is not None: encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order)
for idx, state in enumerate(encoder_out['encoder_states']): )
encoder_out['encoder_states'][idx] = state.index_select(1, new_order) if encoder_out.encoder_embedding is not None:
encoder_out = encoder_out._replace(
encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order)
)
if encoder_out.encoder_states is not None:
for idx, state in enumerate(encoder_out.encoder_states):
encoder_out.encoder_states[idx] = state.index_select(1, new_order)
return encoder_out return encoder_out
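A self-contained sketch (local namedtuple, dummy tensors) of the reordering above as used during beam search: encoder_out is T x B x C so its batch axis is dim 1, the padding mask and embedding are batch-first so theirs is dim 0, and _replace builds a new EncoderOut rather than mutating the old one.

from collections import namedtuple
import torch

EncoderOut = namedtuple('EncoderOut', ['encoder_out', 'encoder_padding_mask',
                                       'encoder_embedding', 'encoder_states'])

T, B, C = 7, 3, 16
out = EncoderOut(
    encoder_out=torch.randn(T, B, C),
    encoder_padding_mask=torch.zeros(B, T, dtype=torch.bool),
    encoder_embedding=torch.randn(B, T, C),
    encoder_states=None,
)

new_order = torch.tensor([2, 0, 1])
out = out._replace(
    encoder_out=out.encoder_out.index_select(1, new_order),                    # dim 1
    encoder_padding_mask=out.encoder_padding_mask.index_select(0, new_order),  # dim 0
    encoder_embedding=out.encoder_embedding.index_select(0, new_order),        # dim 0
)
print(out.encoder_out.shape)                       # torch.Size([7, 3, 16])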
def max_positions(self): def max_positions(self):
...@@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -532,13 +549,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_out=None, encoder_out=None,
incremental_state=None, incremental_state=None,
features_only=False, features_only=False,
**extra_args, **extra_args
): ):
""" """
Args: Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing `(batch, tgt_len)`, for teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for encoder_out (optional): output from the encoder, used for
encoder-side attention encoder-side attention
incremental_state (dict): dictionary used for storing state during incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding` :ref:`Incremental decoding`
...@@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -551,7 +568,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
- a dictionary with any model-specific outputs - a dictionary with any model-specific outputs
""" """
x, extra = self.extract_features( x, extra = self.extract_features(
prev_output_tokens, encoder_out, incremental_state, **extra_args, prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
**extra_args
) )
if not features_only: if not features_only:
x = self.output_layer(x) x = self.output_layer(x)
...@@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -628,9 +648,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
encoder_state = None encoder_state = None
if encoder_out is not None: if encoder_out is not None:
if self.layer_wise_attention: if self.layer_wise_attention:
encoder_state = encoder_out['encoder_states'][idx] encoder_state = encoder_out.encoder_states[idx]
else: else:
encoder_state = encoder_out['encoder_out'] encoder_state = encoder_out.encoder_out
if incremental_state is None and not full_context_alignment: if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x) self_attn_mask = self.buffered_future_mask(x)
...@@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -643,7 +663,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
x, layer_attn = layer( x, layer_attn = layer(
x, x,
encoder_state, encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None, encoder_out.encoder_padding_mask if encoder_out is not None else None,
incremental_state, incremental_state,
self_attn_mask=self_attn_mask, self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask, self_attn_padding_mask=self_attn_padding_mask,
......
...@@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module): ...@@ -26,16 +26,15 @@ class MeanPoolGatingNetwork(torch.nn.Module):
def forward(self, encoder_out): def forward(self, encoder_out):
if not ( if not (
isinstance(encoder_out, dict) hasattr(encoder_out, 'encoder_out')
and 'encoder_out' in encoder_out and hasattr(encoder_out, 'encoder_padding_mask')
and 'encoder_padding_mask' in encoder_out and encoder_out.encoder_out.size(2) == self.embed_dim
and encoder_out['encoder_out'].size(2) == self.embed_dim
): ):
raise ValueError('Unexpected format for encoder_out') raise ValueError('Unexpected format for encoder_out')
# mean pooling over time # mean pooling over time
encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T encoder_padding_mask = encoder_out.encoder_padding_mask # B x T
encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C
if encoder_padding_mask is not None: if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0 encoder_out[encoder_padding_mask] = 0
......
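A sketch of the masked mean pooling this gating network performs (dummy tensors; the real module goes on to turn the pooled vector into expert gating probabilities): padded time steps are zeroed, then the sum over time is divided by each sentence's true length.

import torch

B, T, C = 2, 5, 4
encoder_out = torch.randn(T, B, C).transpose(0, 1)            # B x T x C
encoder_padding_mask = torch.tensor([[False, False, False, False, False],
                                     [False, False, False, True,  True ]])

pooled = encoder_out.clone()                                  # clone before in-place edit
pooled[encoder_padding_mask] = 0                              # zero out padded steps
lengths = (~encoder_padding_mask).sum(dim=1, keepdim=True)    # true lengths, B x 1
mean = pooled.sum(dim=1) / lengths.type_as(pooled)
print(mean.shape)                                             # torch.Size([2, 4])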
...@@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase): ...@@ -197,51 +197,90 @@ class TestTranslation(unittest.TestCase):
]) ])
generate_main(data_dir) generate_main(data_dir)
def test_cmlm_transformer(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'cmlm_transformer', [
'--apply-bert-init',
'--criterion', 'nat_loss',
'--noise', 'full_mask',
'--pred-length-offset',
'--length-loss-factor', '0.1'
], task='translation_lev')
generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_levenshtein_transformer(self): def test_levenshtein_transformer(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir: with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_translation_data(data_dir) preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'levenshtein_transformer', [ train_translation_model(data_dir, 'levenshtein_transformer', [
'--apply-bert-init', '--early-exit', '6,6,6', '--apply-bert-init', '--early-exit', '6,6,6',
'--criterion', 'nat_loss' '--criterion', 'nat_loss'
], task='translation_lev') ], task='translation_lev')
generate_main(data_dir) generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_nonautoregressive_transformer(self): def test_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_translation_data(data_dir) preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'nonautoregressive_transformer', [ train_translation_model(data_dir, 'nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion', '--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--pred-length-offset', 'nat_loss', '--noise', 'full_mask', '--pred-length-offset',
'--length-loss-factor', '0.1' '--length-loss-factor', '0.1'
], task='translation_lev') ], task='translation_lev')
generate_main(data_dir) generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_iterative_nonautoregressive_transformer(self): def test_iterative_nonautoregressive_transformer(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir: with tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_translation_data(data_dir) preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [ train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [
'--apply-bert-init', '--src-embedding-copy', '--criterion', '--apply-bert-init', '--src-embedding-copy', '--criterion',
'nat_loss', '--noise', 'full_mask', '--stochastic-approx', 'nat_loss', '--noise', 'full_mask', '--stochastic-approx',
'--dae-ratio', '0.5', '--train-step', '3' '--dae-ratio', '0.5', '--train-step', '3'
], task='translation_lev') ], task='translation_lev')
generate_main(data_dir) generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_insertion_transformer(self): def test_insertion_transformer(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir: with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir:
create_dummy_data(data_dir) create_dummy_data(data_dir)
preprocess_translation_data(data_dir) preprocess_translation_data(data_dir, ['--joined-dictionary'])
train_translation_model(data_dir, 'insertion_transformer', [ train_translation_model(data_dir, 'insertion_transformer', [
'--apply-bert-init', '--criterion', 'nat_loss', '--noise', '--apply-bert-init', '--criterion', 'nat_loss', '--noise',
'random_mask' 'random_mask'
], task='translation_lev') ], task='translation_lev')
generate_main(data_dir) generate_main(data_dir, [
'--task', 'translation_lev',
'--iter-decode-max-iter', '9',
'--iter-decode-eos-penalty', '0',
'--print-step',
])
def test_mixture_of_experts(self): def test_mixture_of_experts(self):
with contextlib.redirect_stdout(StringIO()): with contextlib.redirect_stdout(StringIO()):
......