OpenDAS / Fairseq

Commit 36e360d9, authored Apr 05, 2018 by Myle Ott

Use PyTorch LayerNorm and improve weight init

parent fc830685
Changes: showing 3 changed files with 25 additions and 85 deletions (+25 -85)

fairseq/models/transformer.py  +25 -30
fairseq/modules/__init__.py    +0 -2
fairseq/modules/layer_norm.py  +0 -53
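In short: the hand-rolled fairseq.modules.LayerNorm (which routed through F.batch_norm to normalize over the last dimension) is dropped in favour of torch.nn.LayerNorm, and the blanket reset_parameters() loops give way to per-module weight initialization. A minimal sketch of the replacement, assuming PyTorch >= 0.4 where nn.LayerNorm is available:

import torch
import torch.nn as nn

# nn.LayerNorm(embed_dim) normalizes over the trailing embed_dim dimension,
# which is what the deleted custom module provided.
embed_dim = 512
layer_norm = nn.LayerNorm(embed_dim)
x = torch.randn(10, 2, embed_dim)          # (T, B, C), the layout used in the code below
y = layer_norm(x)
print(y.mean(dim=-1).abs().max().item())   # ~0: per-position mean removed
print(y.std(dim=-1).mean().item())         # ~1: per-position std rescaled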
fairseq/models/transformer.py (view file @ 36e360d9)
...
...
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from fairseq.data import LanguagePairDataset
 from fairseq.modules import (
-    LayerNorm, LearnedPositionalEmbedding, MultiheadAttention,
+    LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
 )
 from fairseq import utils
...
...
@@ -117,15 +117,6 @@ class TransformerEncoder(FairseqEncoder):
             for i in range(args.encoder_layers)
         ])
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        for name, p in self.named_parameters():
-            if name.endswith('weight'):
-                nn.init.xavier_uniform(p.data)
-            elif name.endswith('bias'):
-                p.data.zero_()
-
     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(src_tokens)
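Removing reset_parameters() is the "improve weight init" half of the commit: the old loop re-initialized every parameter named "*.weight" with Xavier init, which also overwrote embedding tables that had just been given a scaled normal init. A rough sketch of the two styles (illustration only, not fairseq code; the function names here are hypothetical and the in-place init names of later PyTorch releases are used):

import torch.nn as nn

# Old style: one blanket loop over named_parameters() hits every "*.weight",
# including embedding tables, overwriting their std = embed_dim ** -0.5 init
# (and it would fail outright on 1-D weights, which Xavier cannot handle).
def blanket_reset_parameters(module):
    for name, p in module.named_parameters():
        if name.endswith('weight'):
            nn.init.xavier_uniform_(p.data)
        elif name.endswith('bias'):
            p.data.zero_()

# New style (what this commit moves toward): each module is initialized once,
# at construction time, with an init suited to it.
def build_embedding_and_ffn(vocab_size, embed_dim, ffn_dim):
    embed = nn.Embedding(vocab_size, embed_dim)
    nn.init.normal_(embed.weight, mean=0, std=embed_dim ** -0.5)
    fc = nn.Linear(embed_dim, ffn_dim)
    nn.init.xavier_uniform_(fc.weight)
    nn.init.constant_(fc.bias, 0.)
    return embed, fc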
...
...
@@ -188,15 +179,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        for name, p in self.named_parameters():
-            if name.endswith('weight'):
-                nn.init.xavier_uniform(p.data)
-            elif name.endswith('bias'):
-                p.data.zero_()
+            nn.init.normal(self.embed_out, mean=0, std=embed_dim ** -0.5)

     def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
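The single added line gives the output projection the same scaled normal init as the token embeddings. A quick numerical illustration of why std = embed_dim ** -0.5 is chosen, assuming the model multiplies embeddings by embed_scale = sqrt(embed_dim) before use (as in the Transformer paper, and as the encoder's forward() above suggests):

import torch

embed_dim = 512
weight = torch.empty(10000, embed_dim).normal_(mean=0, std=embed_dim ** -0.5)
print(weight.std().item())                         # ~0.044, i.e. 512 ** -0.5
print(((embed_dim ** 0.5) * weight).std().item())  # ~1.0 after multiplying by embed_scale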
...
...
@@ -220,11 +203,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         # decoder layers
         for layer in self.layers:
             x, attn = layer(
-                x,
-                encoder_out['encoder_out'],
-                encoder_out['encoder_padding_mask'],
-                incremental_state,
-            )
+                x,
+                encoder_out['encoder_out'],
+                encoder_out['encoder_padding_mask'],
+                incremental_state,
+            )

         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
...
...
@@ -271,8 +254,8 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout = args.dropout
         self.relu_dropout = args.relu_dropout
         self.normalize_before = args.encoder_normalize_before
-        self.fc1 = nn.Linear(self.embed_dim, args.encoder_ffn_embed_dim)
-        self.fc2 = nn.Linear(args.encoder_ffn_embed_dim, self.embed_dim)
+        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
+        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])

     def forward(self, x, encoder_padding_mask):
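The only change in the encoder layer is that fc1/fc2 now go through the Linear() helper added near the bottom of this file, which swaps PyTorch's default nn.Linear init (uniform with bound 1/sqrt(fan_in)) for Xavier/Glorot uniform plus a zero bias. A small sketch of the difference in scale (illustration only; printed values are approximate):

import torch.nn as nn

default = nn.Linear(512, 2048)            # PyTorch's default init
xavier = nn.Linear(512, 2048)
nn.init.xavier_uniform_(xavier.weight)    # what the Linear() helper applies
nn.init.constant_(xavier.bias, 0.)

print(default.weight.std().item())        # ~0.026, uniform with bound 1/sqrt(fan_in)
print(xavier.weight.std().item())         # ~0.028, uniform with bound sqrt(6/(fan_in+fan_out))
print(float(xavier.bias.abs().max()))     # 0.0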
...
...
@@ -317,8 +300,8 @@ class TransformerDecoderLayer(nn.Module):
             self.embed_dim, args.decoder_attention_heads,
             dropout=args.attention_dropout,
         )
-        self.fc1 = nn.Linear(self.embed_dim, args.decoder_ffn_embed_dim)
-        self.fc2 = nn.Linear(args.decoder_ffn_embed_dim, self.embed_dim)
+        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
+        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
         self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])

     def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
...
...
@@ -373,14 +356,26 @@ class TransformerDecoderLayer(nn.Module):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    m.weight.data.normal_(mean=0, std=embedding_dim ** -0.5)
+    nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
     return m


+def LayerNorm(embedding_dim):
+    m = nn.LayerNorm(embedding_dim)
+    return m
+
+
+def Linear(in_features, out_features, bias=True):
+    m = nn.Linear(in_features, out_features, bias)
+    nn.init.xavier_uniform(m.weight)
+    nn.init.constant(m.bias, 0.)
+    return m
+
+
 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-        m.weight.data.normal_(0, 0.1)
+        nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
     else:
         m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
     return m
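The wrapped constructors above use the init names PyTorch exposed at the time (nn.init.normal, nn.init.xavier_uniform, nn.init.constant); later releases deprecate these in favour of in-place variants with a trailing underscore. A sketch of two of the helpers rewritten against the newer names (with a bias guard added here that the original omits):

import torch.nn as nn

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    return m

def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:   # guard added here; the helper in the diff assumes bias=True
        nn.init.constant_(m.bias, 0.)
    return m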
...
...
fairseq/modules/__init__.py (view file @ 36e360d9)
...
...
@@ -8,7 +8,6 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
-from .layer_norm import LayerNorm
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
...
...
@@ -18,7 +17,6 @@ __all__ = [
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
-    'LayerNorm',
     'LearnedPositionalEmbedding',
     'LinearizedConvolution',
     'MultiheadAttention',
...
...
fairseq/modules/layer_norm.py deleted 100644 → 0 (view file @ fc830685)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """Applies Layer Normalization over the last dimension."""

    def __init__(self, features, eps=1e-5):
        super().__init__()
        self.features = features
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.dummy = None
        self.w = None
        self.b = None

    def forward(self, input):
        shape = input.size()
        # In order to force the cudnn path, everything needs to be
        # contiguous. Hence the check here and reallocation below.
        if not input.is_contiguous():
            input = input.contiguous()
        input = input.view(1, -1, shape[-1])

        # Expand w and b buffers if necessary.
        n = input.size(1)
        cur = self.dummy.numel() if self.dummy is not None else 0
        if cur == 0:
            self.dummy = input.data.new(n)
            self.w = input.data.new(n).fill_(1)
            self.b = input.data.new(n).zero_()
        elif n > cur:
            self.dummy.resize_(n)
            self.w.resize_(n)
            self.w[cur:n].fill_(1)
            self.b.resize_(n)
            self.b[cur:n].zero_()
        dummy = self.dummy[:n]
        w = Variable(self.w[:n])
        b = Variable(self.b[:n])

        output = F.batch_norm(input, dummy, dummy, w, b, True, 0., self.eps)
        return torch.addcmul(self.bias, 1, output.view(*shape), self.gain)
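For reference, the trick in the deleted module was to flatten the input to shape (1, N, features) and run F.batch_norm so that its per-channel statistics become per-position statistics over the feature dimension. A rough equivalence check against nn.LayerNorm (illustration only; written for current PyTorch, so no Variable wrapper):

import torch
import torch.nn as nn
import torch.nn.functional as F

features, eps = 8, 1e-5
x = torch.randn(4, 5, features)

# The deleted module's path: view as (1, N, features); each "channel" is one
# position, so batch_norm normalizes each position over its feature dimension.
flat = x.contiguous().view(1, -1, features)
n = flat.size(1)
dummy = flat.new_zeros(n)    # running stats, left untouched because momentum is 0
w = flat.new_ones(n)
b = flat.new_zeros(n)
old = F.batch_norm(flat, dummy, dummy, w, b, True, 0., eps).view_as(x)

# PyTorch's built-in LayerNorm over the last dimension, affine disabled to
# match the gain = 1 / bias = 0 case above.
new = nn.LayerNorm(features, eps=eps, elementwise_affine=False)(x)

print(torch.allclose(old, new, atol=1e-5))   # True, up to numerical precision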