Commit 49c9895b authored by Guolin Ke

fix bug in return future mask

parent 689e0b24
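Context for the change: the diff below adds a helper that builds an additive "future" (causal) mask and fixes how that mask is returned. A minimal, self-contained sketch (not part of the commit; names and shapes are illustrative only) of how such an additive mask blocks attention to later positions:

import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)                        # raw attention scores
# -inf strictly above the diagonal, 0 on and below it
future_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), 1)
probs = torch.softmax(scores + future_mask, dim=-1)           # additive mask applied before softmax
# entries above the diagonal get zero probability, so position i never attends to j > i
assert torch.allclose(probs.triu(1), torch.zeros(seq_len, seq_len))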
@@ -11,14 +11,17 @@ import torch.nn.functional as F
from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket

def fill_with_neg_inf(t):
    return t.fill_(float("-inf"))

def bulid_future_mask(seq_len):
    return torch.triu(
        fill_with_neg_inf(torch.zeros([seq_len, seq_len])), 1
    )

class TransformerDecoder(nn.Module):
    def __init__(
        self,
@@ -66,7 +69,7 @@ class TransformerDecoder(nn.Module):
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    post_ln=post_ln,
                )
                for _ in range(decoder_layers)
            ]
@@ -77,7 +80,8 @@ class TransformerDecoder(nn.Module):
            assert rel_pos_bins % 2 == 0
            self.rel_pos_bins = rel_pos_bins
            self.max_rel_pos = max_rel_pos
-            self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads)
+            self.relative_attention_bias = nn.Embedding(
+                self.rel_pos_bins, self.attention_heads)
            seq_len = self.max_seq_len
            context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
            memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
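For readers unfamiliar with this setup: the context/memory positions above feed a relative-position bias that is later embedded per attention head. A rough sketch of the idea, using a simplified clamp-based bucketing as a stand-in for the repository's relative_position_bucket (whose exact implementation is not shown in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, heads, bins = 6, 2, 8
context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
relative_position = memory_position - context_position            # (seq_len, seq_len) signed distances
# stand-in bucketing: shift into [0, bins) and clamp (the real bucketing is log-spaced)
rp_bucket = torch.clamp(relative_position + bins // 2, 0, bins - 1)
relative_attention_bias = nn.Embedding(bins, heads)
values = F.embedding(rp_bucket, relative_attention_bias.weight)    # (seq_len, seq_len, heads)
values = values.permute([2, 0, 1]).contiguous()                    # (heads, seq_len, seq_len), one bias map per head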
@@ -98,7 +102,7 @@ class TransformerDecoder(nn.Module):
        values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
        values = values.permute([2, 0, 1])
        return values.contiguous()

    def get_future_mask(self, x, attn_mask):
        if not self.auto_regressive:
            return attn_mask
@@ -108,9 +112,12 @@ class TransformerDecoder(nn.Module):
        self._future_mask = self._future_mask.type_as(x)
        if attn_mask is None:
            ret = self._future_mask[:x.size(1), :x.size(1)]
-            ret = ret.contiguous().unsqueeze(0).repeat(x.size(0)*self.attention_heads, 1, 1)
+            ret = ret.contiguous().unsqueeze(0).repeat(
+                x.size(0)*self.attention_heads, 1, 1)
            return ret
        else:
-            assert list(attn_mask.size()) == [x.size(0) * self.attention_heads, x.size(1), x.size(1)]
+            assert list(attn_mask.size()) == [x.size(
+                0) * self.attention_heads, x.size(1), x.size(1)]
            return attn_mask + self._future_mask[:x.size(1), :x.size(1)]
    def forward(
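The fixed get_future_mask either repeats the causal mask once per (batch, head) pair or adds it on top of an existing per-head bias. A shape-only sketch of that contract, assuming (as the assert above checks) that attn_mask is laid out as (batch * heads, seq_len, seq_len):

import torch

bsz, heads, seq_len = 2, 4, 5
future = torch.triu(torch.full((seq_len, seq_len), float("-inf")), 1)

# no incoming bias: repeat the (seq_len, seq_len) mask once per (batch, head) pair
ret = future.contiguous().unsqueeze(0).repeat(bsz * heads, 1, 1)
assert list(ret.size()) == [bsz * heads, seq_len, seq_len]

# incoming bias (e.g. a relative-position bias): broadcasting adds the mask to every slice
attn_mask = torch.zeros(bsz * heads, seq_len, seq_len)
combined = attn_mask + future[:seq_len, :seq_len]
assert combined.shape == (bsz * heads, seq_len, seq_len)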
@@ -122,16 +129,17 @@ class TransformerDecoder(nn.Module):
        attn_mask: Optional[torch.Tensor] = None,
        encoder_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        seq_len = emb.size(1)
        x = self.emb_layer_norm(emb)
        x = F.dropout(x, p=self.emb_dropout, training=self.training)

        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

-        rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None
+        rel_pos_bias = self.get_rel_pos_bias(x).repeat(
+            x.size(0), 1, 1) if self.rel_pos else None

        if attn_mask is None:
            attn_mask = rel_pos_bias
@@ -150,12 +158,12 @@ class TransformerDecoder(nn.Module):
            )
            attn_mask = attn_mask.view(-1, seq_len, seq_len)
            padding_mask = None

        for layer in self.layers:
            x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask,
                      encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask)

-        if self.final_layer_norm != None:
+        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
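The lines just above close the branch that folds padding_mask into attn_mask so that the layers only receive a single additive bias; the full merge is not visible in this hunk. A hedged sketch of how such a merge typically works, assuming a boolean padding mask with True at padded key positions:

import torch

bsz, heads, seq_len = 2, 4, 5
attn_mask = torch.zeros(bsz * heads, seq_len, seq_len)        # e.g. rel-pos bias plus future mask
padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
padding_mask[:, -1] = True                                    # last token of each sequence is padding

attn_mask = attn_mask.view(bsz, -1, seq_len, seq_len)
attn_mask.masked_fill_(
    padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),    # (bsz, 1, 1, seq_len), broadcast over heads/queries
    float("-inf"),
)
attn_mask = attn_mask.view(-1, seq_len, seq_len)              # back to (bsz * heads, seq_len, seq_len)
# padded keys now carry -inf for every query and head, so they receive zero attention weight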
@@ -157,7 +157,7 @@ class TransformerEncoder(nn.Module):
        for layer in self.layers:
            x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask)

-        if self.final_layer_norm != None:
+        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
\ No newline at end of file