Commit 36e360d9 authored by Myle Ott

Use PyTorch LayerNorm and improve weight init

parent fc830685
fairseq/models/transformer.py
@@ -12,7 +12,7 @@ import torch.nn.functional as F
from fairseq.data import LanguagePairDataset
from fairseq.modules import (
LayerNorm, LearnedPositionalEmbedding, MultiheadAttention,
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
)
from fairseq import utils
@@ -117,15 +117,6 @@ class TransformerEncoder(FairseqEncoder):
for i in range(args.encoder_layers)
])
self.reset_parameters()
def reset_parameters(self):
for name, p in self.named_parameters():
if name.endswith('weight'):
nn.init.xavier_uniform(p.data)
elif name.endswith('bias'):
p.data.zero_()
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
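The removed reset_parameters loop applied xavier_uniform to every parameter whose name ends in 'weight', which also catches embed_tokens.weight (and, once nn.LayerNorm is used, its gain), clobbering their dedicated initializations. A minimal standalone illustration of why the name-based loop is fragile, using a toy nn.Sequential stand-in rather than the actual encoder:

import torch.nn as nn

# Toy stand-in (not the fairseq encoder): a name-based init loop matches
# every '*.weight' parameter, including the embedding table and the
# nn.LayerNorm gain, which want different initializations than Xavier.
model = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8), nn.Linear(8, 8))
for name, p in model.named_parameters():
    if name.endswith('weight'):
        print(name)  # '0.weight', '1.weight', '2.weight' -- all three match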
@@ -188,15 +179,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
self.reset_parameters()
def reset_parameters(self):
for name, p in self.named_parameters():
if name.endswith('weight'):
nn.init.xavier_uniform(p.data)
elif name.endswith('bias'):
p.data.zero_()
nn.init.normal(self.embed_out, mean=0, std=embed_dim**-0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# embed positions
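The new embed_out initialization above, like the Embedding helper further down, draws weights with std = embed_dim ** -0.5, so that after multiplying by self.embed_scale (sqrt(embed_dim) in the standard Transformer recipe, as seen in the encoder's forward) activations come out with roughly unit variance. A quick standalone check of that arithmetic, illustration only and not part of the commit:

import torch

# With std = embed_dim ** -0.5, scaling by sqrt(embed_dim) restores ~unit variance.
embed_dim = 512
weights = torch.empty(1000, embed_dim).normal_(mean=0, std=embed_dim ** -0.5)
scaled = (embed_dim ** 0.5) * weights
print(round(weights.std().item(), 3))  # ~0.044, i.e. 1 / sqrt(512)
print(round(scaled.std().item(), 3))   # ~1.0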
@@ -220,11 +203,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
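As the comment above notes, the decoder stack keeps activations time-major (T x B x C) and only flips to batch-major on the way out. A tiny shape sketch, illustration only:

import torch

# Time-major activations (T x B x C) transposed to batch-major (B x T x C).
T, B, C = 10, 4, 512
x = torch.randn(T, B, C)
print(x.transpose(0, 1).shape)  # torch.Size([4, 10, 512])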
@@ -271,8 +254,8 @@ class TransformerEncoderLayer(nn.Module):
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
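The normalize_before flag read above selects between post-norm (LayerNorm after the residual add, as in the original Transformer) and pre-norm (LayerNorm before each sublayer). A rough sketch of that control flow, with an assumed generic sublayer/layer_norm pair rather than the real encoder layer:

import torch.nn.functional as F

def residual_block(x, sublayer, layer_norm, normalize_before, dropout_p, training=True):
    # Pre-norm normalizes the sublayer input; post-norm normalizes the residual sum.
    residual = x
    if normalize_before:
        x = layer_norm(x)
    x = sublayer(x)
    x = F.dropout(x, p=dropout_p, training=training)
    x = residual + x
    if not normalize_before:
        x = layer_norm(x)
    return x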
@@ -317,8 +300,8 @@ class TransformerDecoderLayer(nn.Module):
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
@@ -373,14 +356,26 @@ class TransformerDecoderLayer(nn.Module):
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
m.weight.data.normal_(mean=0, std=embedding_dim**-0.5)
nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform(m.weight)
nn.init.constant(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
m.weight.data.normal_(0, 0.1)
nn.init.normal(m.weight, mean=0, std=embedding_dim**-0.5)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
return m
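The factory helpers above centralize initialization: Xavier-uniform weights with zero bias for linear layers, and a normal draw with std = embedding_dim ** -0.5 for embedding tables. For reference, a sketch of the same helpers written against current PyTorch, where the in-place initializers carry a trailing underscore (nn.init.xavier_uniform_, nn.init.normal_, nn.init.constant_); the commit itself uses the older spellings, which were later deprecated:

import torch.nn as nn

def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)  # modern spelling of xavier_uniform
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    return m

fc1 = Linear(512, 2048)         # e.g. an encoder FFN projection
emb = Embedding(32000, 512, 1)  # hypothetical vocab size; padding_idx is illustrative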
fairseq/modules/__init__.py
@@ -8,7 +8,6 @@
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
from .grad_multiply import GradMultiply
from .layer_norm import LayerNorm
from .learned_positional_embedding import LearnedPositionalEmbedding
from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
@@ -18,7 +17,6 @@ __all__ = [
'BeamableMM',
'ConvTBC',
'GradMultiply',
'LayerNorm',
'LearnedPositionalEmbedding',
'LinearizedConvolution',
'MultiheadAttention',
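With LayerNorm dropped from fairseq.modules, call sites rely on the small LayerNorm(embedding_dim) helper in the model file, which simply returns nn.LayerNorm, instead of the custom module deleted below. A minimal sketch of the usage seen in the layer classes above:

import torch.nn as nn

def LayerNorm(embedding_dim):
    # Thin wrapper mirroring the helper in the diff above.
    return nn.LayerNorm(embedding_dim)

# Two norms per encoder layer (self-attention + FFN), three per decoder layer
# (self-attention + encoder attention + FFN), matching range(2) / range(3) above.
encoder_norms = nn.ModuleList([LayerNorm(512) for _ in range(2)])
decoder_norms = nn.ModuleList([LayerNorm(512) for _ in range(3)])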
fairseq/modules/layer_norm.py (deleted)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class LayerNorm(nn.Module):
"""Applies Layer Normalization over the last dimension."""
def __init__(self, features, eps=1e-5):
super().__init__()
self.features = features
self.eps = eps
self.gain = nn.Parameter(torch.ones(features))
self.bias = nn.Parameter(torch.zeros(features))
self.dummy = None
self.w = None
self.b = None
def forward(self, input):
shape = input.size()
# In order to force the cudnn path, everything needs to be
# contiguous. Hence the check here and reallocation below.
if not input.is_contiguous():
input = input.contiguous()
input = input.view(1, -1, shape[-1])
# Expand w and b buffers if necessary.
n = input.size(1)
cur = self.dummy.numel() if self.dummy is not None else 0
if cur == 0:
self.dummy = input.data.new(n)
self.w = input.data.new(n).fill_(1)
self.b = input.data.new(n).zero_()
elif n > cur:
self.dummy.resize_(n)
self.w.resize_(n)
self.w[cur:n].fill_(1)
self.b.resize_(n)
self.b[cur:n].zero_()
dummy = self.dummy[:n]
w = Variable(self.w[:n])
b = Variable(self.b[:n])
output = F.batch_norm(input, dummy, dummy, w, b, True, 0., self.eps)
return torch.addcmul(self.bias, 1, output.view(*shape), self.gain)
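The deleted module above emulated layer normalization by flattening the input and calling F.batch_norm with dummy running statistics, a workaround from before PyTorch shipped nn.LayerNorm. A quick numerical sanity sketch, not part of the commit, that the built-in, normalizing over the last dimension, computes the same thing at initialization (gain 1, bias 0):

import torch
import torch.nn as nn

# Compare nn.LayerNorm (affine parameters at their defaults) with a manual
# last-dimension normalization like the removed module performed.
x = torch.randn(7, 3, 16)  # e.g. T x B x C
built_in = nn.LayerNorm(16, eps=1e-5)(x)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + 1e-5)
print(torch.allclose(built_in, manual, atol=1e-5))  # True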