Commit 60c4081b authored by Myle Ott

More improvements to weight init and FP16 support

parent 36e360d9
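
The FP16 part of this change applies one pattern throughout: anything that writes -inf into attention scores (masking padded keys, filling the future-timesteps mask) is done in float32 and then cast back to the tensor's original dtype. Below is a minimal, self-contained sketch of that pattern, assuming a recent PyTorch; the function name masked_softmax, the shapes, and the boolean mask are illustrative only and not part of the commit itself.

import torch
import torch.nn.functional as F

def masked_softmax(scores, pad_mask):
    # scores: (batch, tgt_len, src_len), possibly torch.half
    # pad_mask: (batch, 1, src_len), True where the key position is padding
    # Do the -inf fill in float32, then cast back, mirroring the diff below.
    filled = scores.float().masked_fill(pad_mask, float('-inf')).type_as(scores)
    return F.softmax(filled, dim=-1)

if __name__ == '__main__':
    scores = torch.randn(2, 3, 4)
    pad_mask = torch.zeros(2, 1, 4, dtype=torch.bool)
    pad_mask[:, :, -1] = True  # pretend the last source position is padding
    probs = masked_softmax(scores, pad_mask)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 3))
    assert (probs[..., -1] == 0).all()  # no attention mass on the padded key

With half-precision inputs the same call path applies; the float() round-trip presumably sidesteps issues with writing -inf directly into half tensors in the PyTorch version of the time.
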
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-import math
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -30,20 +28,21 @@ class MultiheadAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self._mask = None
 
-        self.in_proj_weight = Parameter(torch.Tensor(3*self.embed_dim, self.embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*self.embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
-        self.out_proj = nn.Linear(self.embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight.data)
-        nn.init.xavier_uniform(self.out_proj.weight.data)
+        nn.init.xavier_uniform(self.in_proj_weight)
+        nn.init.xavier_uniform(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            self.in_proj_bias.data.zero_()
+            nn.init.constant(self.in_proj_bias, 0.)
+            nn.init.constant(self.out_proj.bias, 0.)
 
     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,
@@ -125,10 +124,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
+            attn_weights = attn_weights.float().masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
-                -math.inf,
-            )
+                float('-inf'),
+            ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
@@ -178,14 +177,13 @@ class MultiheadAttention(nn.Module):
     def buffered_mask(self, tensor):
         dim = tensor.size(-1)
         if self._mask is None:
-            self._mask = torch.triu(tensor.new(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
         if self._mask.size(0) < dim:
-            self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
         return self._mask[:dim, :dim]
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
@@ -194,10 +192,10 @@ class MultiheadAttention(nn.Module):
     def _get_input_buffer(self, incremental_state):
         return utils.get_incremental_state(
             self,
             incremental_state,
             'attn_state',
         ) or {}
 
     def _set_input_buffer(self, incremental_state, buffer):
         utils.set_incremental_state(
...
@@ -375,3 +375,8 @@ def item(tensor):
     if hasattr(tensor, '__getitem__'):
         return tensor[0]
     return tensor
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
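
For context, here is a small usage sketch of the new utils.fill_with_neg_inf helper, mirroring how buffered_mask builds the future-timesteps mask in the hunk above. It assumes a recent PyTorch; the dim value and variable names are made up for illustration.

import torch

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf (as added above)."""
    return t.float().fill_(float('-inf')).type_as(t)

dim = 5
base = torch.empty(dim, dim)                      # stands in for tensor.new(dim, dim)
future_mask = torch.triu(fill_with_neg_inf(base), 1)
# Strictly above the diagonal the mask is -inf, on and below it is 0, so adding
# it to attention scores before softmax blocks attention to future positions
# while leaving past and current positions untouched.
print(future_mask)

Doing the fill in float32 and casting back keeps the helper safe for half-precision tensors, the same cast-and-restore trick applied to masked_fill on the attention weights.
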