Commit fedc55ec (parent f7f2dd01)
Authored Aug 16, 2018 by ngimel; committed by Myle Ott, Aug 16, 2018

add end-of-stack normalizations in case normalize_before has been set (#244)

1 changed file with 25 additions and 0 deletions (+25, -0)
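For context, `normalize_before` switches each encoder/decoder layer from post-norm (sublayer, residual, then LayerNorm) to pre-norm (LayerNorm, then sublayer, then residual). In a pre-norm stack the residual stream coming out of the last layer is never normalized, which is the gap this commit closes by applying one LayerNorm after the whole stack. A minimal, self-contained sketch of that effect follows; it is not fairseq's layer code, and `PreNormBlock` is an illustrative toy module:

    import torch
    import torch.nn as nn

    class PreNormBlock(nn.Module):
        """Toy pre-norm block: LayerNorm is applied *before* the sublayer,
        so the residual stream itself never gets normalized."""
        def __init__(self, dim):
            super().__init__()
            self.norm = nn.LayerNorm(dim)
            self.fc = nn.Linear(dim, dim)

        def forward(self, x):
            return x + self.fc(self.norm(x))  # output scale drifts as layers accumulate

    dim, n_layers = 8, 6
    stack = nn.Sequential(*[PreNormBlock(dim) for _ in range(n_layers)])
    final_norm = nn.LayerNorm(dim)  # the end-of-stack norm this commit adds

    x = torch.randn(4, dim)
    y = stack(x)
    print(y.std(-1).mean())              # un-normalized residual stream, scale depends on depth
    print(final_norm(y).std(-1).mean())  # roughly unit scale after the final LayerNorm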
fairseq/models/transformer.py (+25, -0)
@@ -200,6 +200,10 @@ class TransformerEncoder(FairseqEncoder):
             TransformerEncoderLayer(args)
             for i in range(args.encoder_layers)
         ])
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.encoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)

     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
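The `version` tensor is registered as a buffer rather than a parameter, so it is saved and restored with the module's `state_dict` but never touched by the optimizer; that is what lets `upgrade_state_dict` further down distinguish old checkpoints (no version entry) from new ones. A small standalone illustration of that buffer behavior, not fairseq code:

    import torch
    import torch.nn as nn

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.register_buffer('version', torch.Tensor([2]))

    m = M()
    print('version' in m.state_dict())                           # True: serialized with the checkpoint
    print(any(n == 'version' for n, _ in m.named_parameters()))  # False: not a trainable parameter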
@@ -220,6 +224,9 @@ class TransformerEncoder(FairseqEncoder):
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)

+        if self.normalize:
+            x = self.layer_norm(x)
+
         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
@@ -245,6 +252,11 @@ class TransformerEncoder(FairseqEncoder):
             if 'encoder.embed_positions.weights' in state_dict:
                 del state_dict['encoder.embed_positions.weights']
             state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
+        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
+            #earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['encoder.version'] = torch.Tensor([1])
         return state_dict
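Checkpoints saved before this commit carry no `encoder.version` entry, so `state_dict.get('encoder.version', torch.Tensor([1]))[0]` falls back to 1 and the new end-of-stack LayerNorm is disabled, keeping the loaded model numerically identical to what was trained. A hedged sketch of that gating in isolation; the `upgrade` helper, `Dummy` holder, and the empty checkpoint dict are made up for illustration:

    import torch

    def upgrade(state_dict, model):
        # Old checkpoints have no 'encoder.version' key, so the default
        # Tensor([1]) marks them as pre-dating the end-of-stack norm.
        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
            model.layer_norm = None   # drop the end-of-stack LayerNorm
            model.normalize = False
            state_dict['encoder.version'] = torch.Tensor([1])
        return state_dict

    class Dummy:
        pass

    model = Dummy()
    model.layer_norm, model.normalize = object(), True
    upgrade({}, model)        # hypothetical old checkpoint: no version key
    print(model.normalize)    # False: old checkpoints skip the new LayerNorm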
@@ -285,6 +297,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.decoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)

     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # embed positions
@@ -317,6 +333,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 incremental_state,
             )

+        if self.normalize:
+            x = self.layer_norm(x)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
@@ -354,6 +373,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                     if k in state_dict:
                         state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
                         del state_dict[k]
+        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+            #earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['decoder.version'] = torch.Tensor([1])
+
         return state_dict