Commit 49c9895b authored by Guolin Ke's avatar Guolin Ke
Browse files

Fix a bug in returning the future (causal) attention mask.

parent 689e0b24
......@@ -11,14 +11,17 @@ import torch.nn.functional as F
from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket
def fill_with_neg_inf(t):
    """Fill tensor *t* in place with ``-inf`` and return it.

    The fill is done with ``Tensor.fill_``, so the caller's tensor is
    mutated; the same tensor object is returned for chaining.
    """
    neg_inf = float("-inf")
    t.fill_(neg_inf)
    return t
def bulid_future_mask(seq_len):
    """Build an additive causal attention mask of shape (seq_len, seq_len).

    Entries strictly above the main diagonal are ``-inf`` (future positions
    are blocked); entries on and below the diagonal are ``0``.

    NOTE(review): the name is misspelled ("bulid"); kept as-is for
    backward compatibility with existing callers.
    """
    mask = torch.zeros([seq_len, seq_len])
    mask.fill_(float("-inf"))
    # triu(..., 1) keeps the strict upper triangle at -inf and zeroes the rest.
    return torch.triu(mask, 1)
class TransformerDecoder(nn.Module):
def __init__(
self,
......@@ -77,7 +80,8 @@ class TransformerDecoder(nn.Module):
assert rel_pos_bins % 2 == 0
self.rel_pos_bins = rel_pos_bins
self.max_rel_pos = max_rel_pos
self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads)
self.relative_attention_bias = nn.Embedding(
self.rel_pos_bins, self.attention_heads)
seq_len = self.max_seq_len
context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
......@@ -108,9 +112,12 @@ class TransformerDecoder(nn.Module):
self._future_mask = self._future_mask.type_as(x)
if attn_mask is None:
ret = self._future_mask[:x.size(1), :x.size(1)]
ret = ret.contiguous().unsqueeze(0).repeat(x.size(0)*self.attention_heads, 1, 1)
ret = ret.contiguous().unsqueeze(0).repeat(
x.size(0)*self.attention_heads, 1, 1)
return ret
else:
assert list(attn_mask.size()) == [x.size(0) * self.attention_heads, x.size(1), x.size(1)]
assert list(attn_mask.size()) == [x.size(
0) * self.attention_heads, x.size(1), x.size(1)]
return attn_mask + self._future_mask[:x.size(1), :x.size(1)]
def forward(
......@@ -131,7 +138,8 @@ class TransformerDecoder(nn.Module):
if padding_mask is not None:
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None
rel_pos_bias = self.get_rel_pos_bias(x).repeat(
x.size(0), 1, 1) if self.rel_pos else None
if attn_mask is None:
attn_mask = rel_pos_bias
......@@ -155,7 +163,7 @@ class TransformerDecoder(nn.Module):
x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask,
encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask)
if self.final_layer_norm != None:
if self.final_layer_norm is not None:
x = self.final_layer_norm(x)
return x
......@@ -157,7 +157,7 @@ class TransformerEncoder(nn.Module):
for layer in self.layers:
x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask)
if self.final_layer_norm != None:
if self.final_layer_norm is not None:
x = self.final_layer_norm(x)
return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment