Commit cce92bdd authored by Jiatao Gu, committed by Facebook Github Bot

add new_arange function + FIX BUGS of returning attn values

Summary:
Implementation of the Levenshtein Transformer paper.
Add a new helper function "new_arange" to create arange tensors easily (usage sketch below).
Fix bugs in returning attn values for NAT models.
Delete files that are unnecessary or experimental.

Reviewed By: kahne

Differential Revision: D17652009

fbshipit-source-id: 436bbb5d45de2f8067003232de4f2bd51e87719c
parent c4893ca6
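
A quick orientation before the diff: new_arange consolidates the recurring torch.arange(...)[None, :].expand(...) pattern used across the NAT models, and the NAT/insertion inference methods now return an explicit "attn": None entry alongside output_tokens and output_scores, so that downstream code expecting an "attn" key finds one (as the bug-fix note in the summary suggests). The snippet below is a minimal sketch of the helper's behavior, assuming fairseq at this revision is installed (otherwise paste the new_arange definition from the end of the diff); it is illustrative and not part of the commit:

    import torch
    from fairseq.utils import new_arange  # helper added by this commit

    tokens = torch.zeros(2, 5, dtype=torch.long)  # e.g. a padded batch of token ids

    # the pattern repeated across the NAT models before this change
    old = torch.arange(tokens.size(1), device=tokens.device)[None, :].expand_as(tokens).contiguous()

    # the replacement: one call, same result
    new = new_arange(tokens)
    assert torch.equal(old, new)  # each row is [0, 1, 2, 3, 4]

    # an explicit size is also supported, e.g. a 1-D index row of a given length
    row = new_arange(tokens, 7)   # tensor([0, 1, 2, 3, 4, 5, 6]) on tokens' device
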
@@ -10,7 +10,7 @@ Ghazvininejad, Marjan, et al.
arXiv preprint arXiv:1904.09324 (2019).
"""
import torch
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
@@ -20,10 +20,7 @@ def _skeptical_unmasking(output_scores, output_masks, p):
boundary_len = (
(output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p
).long()
skeptical_mask = (
torch.arange(output_masks.size(1), device=output_masks.device)[None, :]
< boundary_len
)
skeptical_mask = new_arange(output_masks) < boundary_len
return skeptical_mask.scatter(1, sorted_index, skeptical_mask)
......
@@ -6,7 +6,8 @@
import numpy as np
import torch
import torch.nn.functional as F
from fairseq import libnat
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
LevenshteinTransformerDecoder,
@@ -51,13 +52,6 @@ neg_scorer = NegativeDistanceScore()
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
B = in_tokens.size(0)
T = in_tokens.size(1)
V = vocab_size
@@ -102,8 +96,7 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi
word_ins_scores.masked_fill_(padding_masks, 0.0)
word_ins_pred.masked_fill_(padding_masks, padding_idx)
in_coords = torch.arange(in_tokens.size(1), device=in_tokens.device)
in_coords = in_coords.unsqueeze(0).repeat(in_tokens.size(0), 1).type_as(in_scores)
in_coords = new_arange(in_tokens).type_as(in_scores)
# shift all padding predictions to infinity
out_coords = (in_coords[:, 1:] - 0.5).masked_fill(
@@ -188,7 +181,7 @@ class InsertionTransformerModel(LevenshteinTransformerModel):
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
return {"output_tokens": output_tokens, "output_scores": output_scores}
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
......
@@ -5,7 +5,8 @@
import torch
import torch.nn.functional as F
from fairseq import libnat
from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip
from fairseq.models.transformer import (
@@ -18,13 +19,6 @@ from fairseq.modules.transformer_sentence_encoder import init_bert_params
def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
@@ -67,13 +61,6 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx):
def _get_del_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
out_seq_len = out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
@@ -100,13 +87,6 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx):
def _get_del_ins_targets(in_tokens, out_tokens, padding_idx):
try:
from fairseq import libnat
except ImportError as e:
import sys
sys.stderr.write('ERROR: missing libnat. run `pip install --editable .`\n')
raise e
in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1)
with torch.cuda.device_of(in_tokens):
@@ -156,7 +136,7 @@ def _apply_ins_masks(
out_lengths = in_lengths + mask_ins_pred.sum(1)
out_max_len = out_lengths.max()
out_masks = (
torch.arange(out_max_len, device=out_lengths.device)[None, :]
new_arange(out_lengths, out_max_len)[None, :]
< out_lengths[:, None]
)
@@ -205,9 +185,7 @@ def _apply_del_words(
word_del_pred.masked_fill_(bos_eos_masks, 0)
reordering = (
torch.arange(max_len, device=in_tokens.device)[None, :]
.expand_as(in_tokens)
.contiguous()
new_arange(in_tokens)
.masked_fill_(word_del_pred, max_len)
.sort(1)[1]
)
......
@@ -10,11 +10,9 @@ from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
Embedding,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerModel,
)
from fairseq.modules import MultiheadAttention
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@@ -35,45 +33,11 @@ def _argmax(x, dim):
return (x == x.max(dim, keepdim=True)[0]).type_as(x)
def _dynamic_programming(tokens, scores):
N, B, T = tokens.size()
cum_scores = scores[:, :, 0].clone() # N x B
cum_choice = tokens.new_zeros(B, T)
# forward
for t in range(T - 1):
score, choice = cum_scores.max(0)
cum_choice[:, t] = choice
cum_scores[0] = score + scores[0, :, t + 1]
cum_scores[1:] = cum_scores[:-1] + scores[1:, :, t + 1]
# back-tracking
end_score, end_choice = cum_scores.max(0)
cum_choice[:, T - 1] = end_choice
for t in range(T - 2, -1, -1):
is_start = (cum_choice[:, t + 1] == 0).type_as(cum_choice)
cum_choice[:, t] = (cum_choice[:, t + 1] - 1) * ~is_start + cum_choice[
:, t
] * is_start
# finalize the prediction
tokens = tokens.gather(0, cum_choice.unsqueeze(0)).squeeze(0)
scores = scores.gather(0, cum_choice.unsqueeze(0)).squeeze(0)
return scores, tokens
def _beam_search(tokens, scores, W=None):
N, B, T = tokens.size()
if (W is None) or (W > N):
W = N
def _uniform_assignment(src_lens, trg_lens):
max_trg_len = trg_lens.max()
steps = (src_lens.float() - 1) / (trg_lens.float() - 1) # step-size
# max_trg_len
index_t = torch.arange(max_trg_len, device=trg_lens.device).float()
index_t = utils.new_arange(trg_lens, max_trg_len).float()
index_t = steps[:, None] * index_t[None, :] # batch_size X max_trg_len
index_t = torch.round(index_t).long().detach()
return index_t
@@ -108,16 +72,6 @@ class NATransformerModel(TransformerModel):
parser.add_argument("--length-loss-factor", type=float,
help="weights on the length prediction loss")
# n-gram predictor
parser.add_argument(
"--ngram-predictor",
nargs="?",
const=4,
default=1,
type=int,
help="adding an additional n-gram predictor.",
)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
decoder = NATransformerDecoder(args, tgt_dict, embed_tokens)
@@ -173,13 +127,13 @@ class NATransformerModel(TransformerModel):
output_tokens.masked_scatter_(output_masks, _tokens[output_masks])
output_scores.masked_scatter_(output_masks, _scores[output_masks])
return {"output_tokens": output_tokens, "output_scores": output_scores}
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
def initialize_output_tokens(self, encoder_out, src_tokens):
# length prediction
_, length_tgt = self.decoder.forward_length_prediction(encoder_out)
max_length = length_tgt.max()
idx_length = torch.arange(max_length, device=src_tokens.device)
idx_length = utils.new_arange(src_tokens, max_length)
initial_output_tokens = src_tokens.new_zeros(
src_tokens.size(0), max_length
@@ -197,6 +151,7 @@ class NATransformerModel(TransformerModel):
return {
"output_tokens": initial_output_tokens,
"output_scores": initial_output_scores,
"attn": None
}
@@ -218,11 +173,6 @@ class NATransformerDecoder(TransformerDecoder):
self.src_embedding_copy = getattr(args, "src_embedding_copy", False)
self.embed_length = Embedding(256, self.encoder_embed_dim, None)
self.ngram_predictor = getattr(args, "ngram_predictor", 1)
self.ngram_layer = (
None if (self.ngram_predictor == 1) else NgramDecoderLayer(args, True)
)
def forward(
self,
prev_output_tokens,
@@ -240,25 +190,12 @@ class NATransformerDecoder(TransformerDecoder):
)
if tgt_tokens is not None:
if self.ngram_layer is None:
word_ins_mask = tgt_tokens.ne(self.padding_idx)
word_ins_tgt = tgt_tokens
else:
context_embeds, context_masks = self.forward_ngram_context(tgt_tokens)
features = self.ngram_layer(features, context_embeds=context_embeds)
word_ins_tgt = tgt_tokens[:, :, None].repeat(1, 1, self.ngram_predictor)
word_ins_mask = word_ins_tgt.ne(self.padding_idx) & context_masks
return self.output_layer(features), word_ins_tgt, word_ins_mask
else:
if self.ngram_layer is None:
return F.log_softmax(self.output_layer(features), -1).max(-1)
else:
# inner iterations
return self.forward_ngram_decoding(
features, prev_output_tokens.eq(self.padding_idx), decoding_format
)
def extract_features(
self,
@@ -336,82 +273,6 @@ class NATransformerDecoder(TransformerDecoder):
return x, {"attn": attn, "inner_states": inner_states}
def forward_ngram_context(self, tgt_tokens):
tgt_embeds = self.forward_embedding(tgt_tokens)
n_contexts = self.ngram_predictor - 1
# shifting the embeddings
# context_embeds: N x B x T x C
# context_masks: B x T x N
context_embeds = tgt_embeds.new_zeros(n_contexts, *tgt_embeds.size())
context_masks = tgt_embeds.new_ones(
*tgt_embeds.size()[:2], self.ngram_predictor
).bool()
for k in range(n_contexts):
context_embeds[k, :, k + 1:] = tgt_embeds[:, : -k - 1]
context_masks[:, : k + 1, k + 1] = 0
return context_embeds, context_masks
def forward_ngram_decoding(self, features, padding_mask=None, decoding_format=None):
context_embeds = None
scores, tokens = [], []
ensemble_score = None
ensemble_index = None
if decoding_format is None:
decoding_format = "ensemble"
for k in range(self.ngram_predictor):
ngram_out = self.ngram_layer(
features, context_embeds=context_embeds, incremental=True
)
ngram_scores = F.log_softmax(self.output_layer(ngram_out), -1)
max_score, max_token = ngram_scores.max(-1)
if decoding_format == "vote":
ngram_scores = _argmax(ngram_scores, -1)
if ensemble_score is None:
ensemble_score = ngram_scores
ensemble_index = ensemble_score.new_ones(*ensemble_score.size()[:2])
else:
ensemble_index[:, k:] = ensemble_index[:, k:] + 1
ensemble_score = ensemble_score + ngram_scores.masked_fill_(
(ensemble_index < k)
.unsqueeze(2)
.repeat(1, 1, ensemble_score.size(2)),
0,
)
max_score[:, :k] = float("-inf")
if decoding_format == "unigram":
break
scores.append(max_score.masked_fill_(padding_mask, 0))
tokens.append(max_token.masked_fill_(padding_mask, self.padding_idx))
# context_embeds: N x B x T x C
if context_embeds is None:
context_embeds = self.forward_embedding(max_token).unsqueeze(0)
else:
context_embeds = torch.cat(
[self.forward_embedding(max_token).unsqueeze(0), context_embeds], 0
)
context_embeds[:, :, 1:] = context_embeds[:, :, :-1]
if decoding_format != "dp":
ensemble_score = ensemble_score / ensemble_index.unsqueeze(2)
return ensemble_score.max(-1)
else:
tokens = torch.cat([t.unsqueeze(0) for t in tokens], 0)
scores = torch.cat([s.unsqueeze(0) for s in scores], 0)
return _dynamic_programming(tokens, scores)
def forward_embedding(self, prev_output_tokens, states=None):
# embed positions
positions = (
@@ -489,101 +350,6 @@ class NATransformerDecoder(TransformerDecoder):
return length_out, length_tgt
class NgramDecoderLayer(TransformerDecoderLayer):
"""
N-gram Decoder Layer:
This module can be plugged into the last layer of any non-autoregressive model.
It provides an alternative way to capture local n-gram information by running the block multiple times.
"""
def __init__(self, args, no_encoder_attn=False):
super(NgramDecoderLayer, self).__init__(args, no_encoder_attn=no_encoder_attn)
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=1, # maybe n-gram does not need too many heads.
dropout=args.attention_dropout,
self_attention=False,
encoder_decoder_attention=True,
)
def forward(
self,
x,
encoder_out=None,
encoder_padding_mask=None,
context_embeds=None,
incremental=False,
):
# x: T x B x C
# context_embeds: N x T x B x C
T, B, C = x.size()
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x = x.contiguous().view(1, T * B, C).contiguous()
if context_embeds is not None:
N = context_embeds.size(0)
context_embeds = context_embeds.view(N, T * B, C).contiguous()
if not incremental:
assert context_embeds is not None, "we need context for training"
# attn_weights: (n_head x T x B) x 1 x N
# v: (n_head x T x B) x N x (dim / n_head)
# -- move the attention computation outside --
attn_weights, values = self.self_attn(
query=x, key=context_embeds, value=context_embeds, before_softmax=True
)
attn_weights = attn_weights.repeat(1, N, 1)
attn_masks = attn_weights.new_ones(N, N).triu_(1).bool()
attn_masks = attn_masks.unsqueeze(0).repeat(attn_weights.size(0), 1, 1)
attn_weights = attn_weights.masked_fill(attn_masks, float("-inf"))
attn_weights = utils.softmax(attn_weights, dim=-1).type_as(attn_weights)
attn_weights = F.dropout(
attn_weights, p=self.self_attn.dropout, training=self.training
)
# (n_head x T x B) x N x (dim / n_head)
attn = torch.bmm(attn_weights, values)
attn = attn.transpose(0, 1).contiguous()
attn = attn.view(N, T * B, C).contiguous()
attn = attn.transpose(1, 0).contiguous()
attn = attn.view(T, B, N, C)
residual = residual.unsqueeze(2)
x = self.self_attn.out_proj(attn)
x = F.dropout(x, p=self.dropout, training=self.training)
x = torch.cat([residual, residual + x], 2)
else:
if context_embeds is None:
x = residual
else:
x, _ = self.self_attn(query=x, key=context_embeds, value=context_embeds)
x = x.view(T, B, C)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
if self.encoder_attn is not None:
raise NotImplementedError
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return x
@register_model_architecture(
"nonautoregressive_transformer", "nonautoregressive_transformer"
)
@@ -630,7 +396,6 @@ def base_architecture(args):
args.pred_length_offset = getattr(args, "pred_length_offset", False)
args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
args.ngram_predictor = getattr(args, "ngram_predictor", 1)
@register_model_architecture(
......
@@ -412,3 +412,13 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
for tgt_idx, src_idx in zip(tgt_valid, src_indices):
alignment.append((src_token_to_word[src_idx.item()] - 1, tgt_token_to_word[tgt_idx.item()] - 1))
return alignment
def new_arange(x, *size):
"""
Return a Tensor of `size` filled with a range along the last dimension, on the device of x.
If `size` is empty, use the size of the variable x.
"""
if len(size) == 0:
size = x.size()
return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
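
For reference, a small usage sketch of the helper defined above, mirroring the length-mask pattern from _apply_ins_masks earlier in the diff; the values are illustrative and not part of the commit:

    import torch
    from fairseq.utils import new_arange

    out_lengths = torch.tensor([3, 5, 2])   # per-sentence output lengths
    out_max_len = out_lengths.max()         # == 5
    out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None]
    # out_masks:
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True],
    #         [ True,  True, False, False, False]])
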