OpenDAS / FastMoE · Commits

Commit 47167bcc
authored Nov 19, 2020 by Jiezhong Qiu

fix bug

parent 82fe21d3
Showing 1 changed file with 6 additions and 6 deletions

pytorch/mem_transformer.py (+6 -6)
@@ -96,7 +96,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         self.d_inner = d_inner
         self.dropout = dropout
-        self.q_net = nn.Linear(d_model, d_model)
+        # self.q_net = nn.Linear(d_model, d_model)
         self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
         self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
@@ -104,7 +104,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
         self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
-        self.o_net = nn.Linear(d_model, d_model)
+        # self.o_net = nn.Linear(d_model, d_model)
         self.layer_norm = nn.LayerNorm(d_model)
@@ -129,8 +129,8 @@ class MultiHeadPositionwiseFF(nn.Module):
         if self.pre_lnorm:
             inp = self.layer_norm(inp)
-        head_q = self.q_net(inp)
-        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
+        # head_q = self.q_net(inp)
+        head_q = inp.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
         attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias # [.. x n_head x d_inner]
         attn_score = F.relu(attn_score)
@@ -139,8 +139,8 @@ class MultiHeadPositionwiseFF(nn.Module):
         attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
         attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
-        core_out = self.o_net(attn_vec)
-        core_out = self.dropout(core_out)
+        # core_out = self.o_net(attn_vec)
+        core_out = self.dropout(attn_vec)
         output = core_out + residual
         if not self.pre_lnorm:
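
For reference, below is a minimal, self-contained sketch of how MultiHeadPositionwiseFF reads after this commit, assembled only from the lines visible in the diff above. The constructor signature, the parameter initialization, the wrapping of dropout in nn.Dropout, and the post-norm branch are assumptions made to keep the sketch runnable; they are not taken from the FastMoE source.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadPositionwiseFFSketch(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, pre_lnorm=False):
        super().__init__()
        # The direct view() on inp below only works when d_model == n_head * d_head.
        assert n_head * d_head == d_model
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.d_inner = d_inner
        # The diff stores the raw dropout value and later calls self.dropout(...);
        # wrapping it in nn.Dropout here is an assumption that keeps the sketch runnable.
        self.dropout = nn.Dropout(dropout)
        self.pre_lnorm = pre_lnorm

        # self.q_net = nn.Linear(d_model, d_model)   # commented out by this commit
        self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
        self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
        self.v_weight = nn.Parameter(torch.Tensor(n_head, d_head, d_inner))
        self.v_bias = nn.Parameter(torch.Tensor(n_head, d_head))
        # self.o_net = nn.Linear(d_model, d_model)   # commented out by this commit
        self.layer_norm = nn.LayerNorm(d_model)

        # Assumed initialization; the real scheme is not visible in this diff.
        nn.init.normal_(self.k_weight, std=0.02)
        nn.init.normal_(self.v_weight, std=0.02)
        nn.init.zeros_(self.k_bias)
        nn.init.zeros_(self.v_bias)

    def forward(self, inp):
        # inp: [seq_len x batch x d_model]
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        # Split the model dimension into heads: [.. x n_head x d_head]
        head_q = inp.view(inp.size(0), inp.size(1), self.n_head, self.d_head)
        # Per-head expansion to the inner dimension: [.. x n_head x d_inner]
        attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias
        attn_score = F.relu(attn_score)
        # Project back to the head dimension: [.. x n_head x d_head]
        attn_vec = torch.einsum('ibnh,ndh->ibnd', (attn_score, self.v_weight)) + self.v_bias
        attn_vec = attn_vec.contiguous().view(inp.size(0), inp.size(1), self.d_model)
        core_out = self.dropout(attn_vec)

        output = core_out + residual
        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output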
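
A quick shape check of the sketch above, with made-up sizes chosen so that d_model == n_head * d_head:

ff = MultiHeadPositionwiseFFSketch(n_head=8, d_model=512, d_head=64, d_inner=256, dropout=0.1)
x = torch.randn(16, 4, 512)   # [seq_len x batch x d_model]
y = ff(x)
print(y.shape)                # torch.Size([16, 4, 512])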