Commit 49c9895b authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in returning the future mask

parent 689e0b24
...@@ -11,14 +11,17 @@ import torch.nn.functional as F ...@@ -11,14 +11,17 @@ import torch.nn.functional as F
from . import TransformerDecoderLayer, LayerNorm from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket from .transformer_encoder import relative_position_bucket
def fill_with_neg_inf(t):
    """Overwrite every element of `t` with -inf, in place, and return `t`.

    Fluent helper used when constructing additive attention masks:
    positions holding -inf are excluded by the subsequent softmax.
    """
    t.fill_(float("-inf"))
    return t
def bulid_future_mask(seq_len):
    """Build a causal ("future") attention mask of shape (seq_len, seq_len).

    Entries strictly above the main diagonal are -inf (blocking attention
    to future positions); entries on and below the diagonal are 0. The
    mask is meant to be added to attention logits before the softmax.

    Args:
        seq_len: length of the (square) target sequence.

    Returns:
        A float tensor of shape (seq_len, seq_len).
    """
    # torch.full builds the -inf-filled tensor in one idiomatic call
    # (equivalent to fill_with_neg_inf(torch.zeros(...))).
    return torch.triu(
        torch.full([seq_len, seq_len], float("-inf")), 1
    )


# Correctly spelled alias; the original (typo'd) name is kept above so
# existing callers continue to work.
build_future_mask = bulid_future_mask
class TransformerDecoder(nn.Module): class TransformerDecoder(nn.Module):
def __init__( def __init__(
self, self,
...@@ -66,7 +69,7 @@ class TransformerDecoder(nn.Module): ...@@ -66,7 +69,7 @@ class TransformerDecoder(nn.Module):
activation_dropout=activation_dropout, activation_dropout=activation_dropout,
activation_fn=activation_fn, activation_fn=activation_fn,
post_ln=post_ln, post_ln=post_ln,
) )
for _ in range(decoder_layers) for _ in range(decoder_layers)
] ]
...@@ -77,7 +80,8 @@ class TransformerDecoder(nn.Module): ...@@ -77,7 +80,8 @@ class TransformerDecoder(nn.Module):
assert rel_pos_bins % 2 == 0 assert rel_pos_bins % 2 == 0
self.rel_pos_bins = rel_pos_bins self.rel_pos_bins = rel_pos_bins
self.max_rel_pos = max_rel_pos self.max_rel_pos = max_rel_pos
self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads) self.relative_attention_bias = nn.Embedding(
self.rel_pos_bins, self.attention_heads)
seq_len = self.max_seq_len seq_len = self.max_seq_len
context_position = torch.arange(seq_len, dtype=torch.long)[:, None] context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
memory_position = torch.arange(seq_len, dtype=torch.long)[None, :] memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
...@@ -98,7 +102,7 @@ class TransformerDecoder(nn.Module): ...@@ -98,7 +102,7 @@ class TransformerDecoder(nn.Module):
values = F.embedding(rp_bucket, self.relative_attention_bias.weight) values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
values = values.permute([2, 0, 1]) values = values.permute([2, 0, 1])
return values.contiguous() return values.contiguous()
def get_future_mask(self, x, attn_mask): def get_future_mask(self, x, attn_mask):
if not self.auto_regressive: if not self.auto_regressive:
return attn_mask return attn_mask
...@@ -108,9 +112,12 @@ class TransformerDecoder(nn.Module): ...@@ -108,9 +112,12 @@ class TransformerDecoder(nn.Module):
self._future_mask = self._future_mask.type_as(x) self._future_mask = self._future_mask.type_as(x)
if attn_mask is None: if attn_mask is None:
ret = self._future_mask[:x.size(1), :x.size(1)] ret = self._future_mask[:x.size(1), :x.size(1)]
ret = ret.contiguous().unsqueeze(0).repeat(x.size(0)*self.attention_heads, 1, 1) ret = ret.contiguous().unsqueeze(0).repeat(
x.size(0)*self.attention_heads, 1, 1)
return ret
else: else:
assert list(attn_mask.size()) == [x.size(0) * self.attention_heads, x.size(1), x.size(1)] assert list(attn_mask.size()) == [x.size(
0) * self.attention_heads, x.size(1), x.size(1)]
return attn_mask + self._future_mask[:x.size(1), :x.size(1)] return attn_mask + self._future_mask[:x.size(1), :x.size(1)]
def forward( def forward(
...@@ -122,16 +129,17 @@ class TransformerDecoder(nn.Module): ...@@ -122,16 +129,17 @@ class TransformerDecoder(nn.Module):
attn_mask: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
encoder_attn_mask: Optional[torch.Tensor] = None, encoder_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
seq_len = emb.size(1) seq_len = emb.size(1)
x = self.emb_layer_norm(emb) x = self.emb_layer_norm(emb)
x = F.dropout(x, p=self.emb_dropout, training=self.training) x = F.dropout(x, p=self.emb_dropout, training=self.training)
# account for padding while computing the representation # account for padding while computing the representation
if padding_mask is not None: if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None rel_pos_bias = self.get_rel_pos_bias(x).repeat(
x.size(0), 1, 1) if self.rel_pos else None
if attn_mask is None: if attn_mask is None:
attn_mask = rel_pos_bias attn_mask = rel_pos_bias
...@@ -150,12 +158,12 @@ class TransformerDecoder(nn.Module): ...@@ -150,12 +158,12 @@ class TransformerDecoder(nn.Module):
) )
attn_mask = attn_mask.view(-1, seq_len, seq_len) attn_mask = attn_mask.view(-1, seq_len, seq_len)
padding_mask = None padding_mask = None
for layer in self.layers: for layer in self.layers:
x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask, x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask,
encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask) encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask)
if self.final_layer_norm != None: if self.final_layer_norm is not None:
x = self.final_layer_norm(x) x = self.final_layer_norm(x)
return x return x
...@@ -157,7 +157,7 @@ class TransformerEncoder(nn.Module): ...@@ -157,7 +157,7 @@ class TransformerEncoder(nn.Module):
for layer in self.layers: for layer in self.layers:
x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask) x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask)
if self.final_layer_norm != None: if self.final_layer_norm is not None:
x = self.final_layer_norm(x) x = self.final_layer_norm(x)
return x return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment