Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a0d99fdb
Commit
a0d99fdb
authored
Jul 09, 2020
by
Szymon Migacz
Browse files
Fixed weight init for fused weight matrices in fused MHA by adding correct gain factor.
parent
1ff54b8f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
2 deletions
+14
-2
apex/contrib/multihead_attn/encdec_multihead_attn.py
apex/contrib/multihead_attn/encdec_multihead_attn.py
+7
-1
apex/contrib/multihead_attn/self_multihead_attn.py
apex/contrib/multihead_attn/self_multihead_attn.py
+7
-1
No files found.
apex/contrib/multihead_attn/encdec_multihead_attn.py
View file @
a0d99fdb
import
math
import
torch
from
torch
import
nn
from
torch.nn
import
Parameter
...
...
@@ -76,7 +78,11 @@ class EncdecMultiheadAttn(nn.Module):
def
reset_parameters
(
self
):
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight_q
)
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight_kv
)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight_kv
,
gain
=
math
.
sqrt
(
1.5
))
nn
.
init
.
xavier_uniform_
(
self
.
out_proj_weight
)
if
self
.
bias
:
nn
.
init
.
constant_
(
self
.
in_proj_bias_q
,
0.
)
...
...
apex/contrib/multihead_attn/self_multihead_attn.py
View file @
a0d99fdb
import
math
import
torch
from
torch
import
nn
from
torch.nn
import
Parameter
...
...
@@ -98,7 +100,11 @@ class SelfMultiheadAttn(nn.Module):
nn
.
init
.
xavier_uniform_
(
self
.
k_weight
)
nn
.
init
.
xavier_uniform_
(
self
.
v_weight
)
else
:
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight
)
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn
.
init
.
xavier_uniform_
(
self
.
in_proj_weight
,
gain
=
math
.
sqrt
(
2
))
nn
.
init
.
xavier_uniform_
(
self
.
out_proj_weight
)
if
self
.
bias
:
if
self
.
separate_qkv_params
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment