Commit 60c4081b authored by Myle Ott

More improvements to weight init and FP16 support

parent 36e360d9
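The "weight init" half of the change rewrites `reset_parameters` to call `nn.init` routines on the Parameters themselves rather than on their `.data` tensors, and to zero both projection biases instead of only `in_proj_bias`. A minimal sketch of that pattern is below; the class name `ProjectionBlock` is invented for illustration, and it uses the underscore-suffixed in-place `nn.init` names from later PyTorch releases rather than the older `nn.init.xavier_uniform` / `nn.init.constant` spellings that appear in the diff.

```python
import torch
from torch import nn
from torch.nn import Parameter


class ProjectionBlock(nn.Module):
    """Hypothetical module with the same parameter layout as MultiheadAttention."""

    def __init__(self, embed_dim, bias=True):
        super().__init__()
        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) if bias else None
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the Parameters directly (not their .data tensors),
        # and zero both projection biases.
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)


block = ProjectionBlock(embed_dim=8)
```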
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-import math
-
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -30,20 +28,21 @@ class MultiheadAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self._mask = None
 
-        self.in_proj_weight = Parameter(torch.Tensor(3*self.embed_dim, self.embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*self.embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
-        self.out_proj = nn.Linear(self.embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight.data)
-        nn.init.xavier_uniform(self.out_proj.weight.data)
+        nn.init.xavier_uniform(self.in_proj_weight)
+        nn.init.xavier_uniform(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            self.in_proj_bias.data.zero_()
+            nn.init.constant(self.in_proj_bias, 0.)
+            nn.init.constant(self.out_proj.bias, 0.)
 
     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,
@@ -125,10 +124,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
+            attn_weights = attn_weights.float().masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
-                -math.inf,
-            )
+                float('-inf'),
+            ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
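The cast-to-float-and-back pattern above is the "FP16 support" half of the change: the padding mask is applied in float32 and the result is cast back to the attention weights' original dtype, so the same code path works when the model runs in half precision, and masked positions still receive exactly zero probability after the softmax. Below is a standalone sketch with made-up shapes; `scores` stands in for `attn_weights` and `pad_mask` for `key_padding_mask`.

```python
import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7            # made-up sizes
scores = torch.randn(bsz, num_heads, tgt_len, src_len)   # stands in for attn_weights
pad_mask = torch.zeros(bsz, src_len, dtype=torch.bool)   # True = padding position
pad_mask[:, -2:] = True                                  # pretend the last two source positions are padding

# Mirror the diff: do the fill in float32, then cast back to the original
# dtype (a no-op here, but the point when scores is half precision).
scores = scores.float().masked_fill(
    pad_mask.unsqueeze(1).unsqueeze(2),
    float('-inf'),
).type_as(scores)

probs = F.softmax(scores, dim=-1)
assert torch.all(probs[..., -2:] == 0)   # padded keys get zero attention
```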
@@ -178,14 +177,13 @@ class MultiheadAttention(nn.Module):
     def buffered_mask(self, tensor):
         dim = tensor.size(-1)
         if self._mask is None:
-            self._mask = torch.triu(tensor.new(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
         if self._mask.size(0) < dim:
-            self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
         return self._mask[:dim, :dim]
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
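`buffered_mask` now routes its -inf fill through the new `utils.fill_with_neg_inf` helper (added in the utils hunk below) so the cached future-timesteps mask can also be built in half precision. A small sketch of the mask it produces, with the helper inlined so the snippet runs without fairseq and with `torch.zeros` standing in for the uninitialized `tensor.new(dim, dim)`:

```python
import torch


def fill_with_neg_inf(t):
    # Inlined stand-in for utils.fill_with_neg_inf from this commit.
    return t.float().fill_(float('-inf')).type_as(t)


# Strictly upper-triangular -inf mask: position i may attend to j <= i only.
dim = 4
future_mask = torch.triu(fill_with_neg_inf(torch.zeros(dim, dim)), 1)
# future_mask:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]
```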
@@ -194,10 +192,10 @@ class MultiheadAttention(nn.Module):
 
     def _get_input_buffer(self, incremental_state):
         return utils.get_incremental_state(
-                self,
-                incremental_state,
-                'attn_state',
-            ) or {}
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
 
     def _set_input_buffer(self, incremental_state, buffer):
         utils.set_incremental_state(
@@ -375,3 +375,8 @@ def item(tensor):
     if hasattr(tensor, '__getitem__'):
         return tensor[0]
     return tensor
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
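A quick, hypothetical usage check of the new helper (not part of the commit), assuming it lives in `fairseq.utils` as the `utils.fill_with_neg_inf` calls above suggest; it shows that the helper preserves the input dtype, which is the point for FP16:

```python
import torch
from fairseq import utils

half = torch.zeros(2, 3, dtype=torch.half)
filled = utils.fill_with_neg_inf(half)
assert filled.dtype == torch.half        # dtype is preserved
assert torch.isinf(filled).all()         # every entry is -inf
```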