OpenDAS / FastMoE · Commits

Commit 1b7cbeb5
Authored Nov 19, 2020 by Jiezhong Qiu
Parent: 47167bcc

    torch sparse for spmm

Changes: 1 changed file with 89 additions and 3 deletions

    pytorch/mem_transformer.py  (+89, -3)
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch_sparse

 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
@@ -82,6 +83,50 @@ class MoEPositionwiseFF(nn.Module):
         return output
         # return output, relu_out.detach()

+class SparsePositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(SparsePositionwiseFF, self).__init__()
+        print("SparsePositionwiseFF")
+
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
+            nn.Dropout(dropout)
+        )
+        self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
+        self.b2 = nn.Parameter(torch.Tensor(d_model))
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+        self.dropout_final = nn.Dropout(dropout)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        temp_Linear = nn.Linear(self.d_inner, self.d_model)
+        self.W2.data = temp_Linear.weight.data.transpose(0, 1)
+        self.b2.data = temp_Linear.bias.data
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
+        sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
+        core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
+        core_out = self.dropout_final(core_out)
+
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+
+        return output
+
 class MultiHeadPositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, n_head=2):
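The new SparsePositionwiseFF replaces the second dense linear layer with a sparse-dense matrix product: the ReLU activation (which is largely zero) is converted to a torch_sparse.SparseTensor and multiplied against W2 with torch_sparse.matmul. Note that relu_out is flattened to 2-D before the product, so the result may need reshaping back to the input's shape before the residual addition. Below is a minimal sketch of that spmm path, assuming the pytorch_sparse package is installed; the tensor sizes are illustrative only.

    import torch
    import torch_sparse

    d_inner, d_model = 16, 8
    relu_out = torch.relu(torch.randn(6, d_inner))   # ReLU output: many entries are exactly zero
    W2 = torch.randn(d_inner, d_model)
    b2 = torch.randn(d_model)

    # Convert the (mostly zero) activation matrix into a torch_sparse SparseTensor
    sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)

    # Sparse @ dense matmul returns a dense tensor of shape [6, d_model]
    core_out = torch_sparse.matmul(sparse_relu_out, W2) + b2
    print(core_out.shape)  # torch.Size([6, 8])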
@@ -96,7 +141,7 @@ class MultiHeadPositionwiseFF(nn.Module):
         self.d_inner = d_inner
         self.dropout = dropout

-        # self.q_net = nn.Linear(d_model, d_model)
+        self.q_net = nn.Linear(d_model, d_model)
         self.k_weight = nn.Parameter(torch.Tensor(n_head, d_inner, d_head))
         self.k_bias = nn.Parameter(torch.Tensor(n_head, d_inner))
@@ -129,8 +174,8 @@ class MultiHeadPositionwiseFF(nn.Module):
         if self.pre_lnorm:
             inp = self.layer_norm(inp)

-        # head_q = self.q_net(inp)
-        head_q = inp.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
+        head_q = self.q_net(inp)
+        head_q = head_q.view(inp.size(0), inp.size(1), self.n_head, self.d_head) # [.. x n_head x d_head]
         attn_score = torch.einsum('ibnd,nhd->ibnh', (head_q, self.k_weight)) + self.k_bias # [.. x n_head x d_inner]
         attn_score = F.relu(attn_score)
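In the MultiHeadPositionwiseFF path, the query projection q_net is re-enabled and its output (rather than the raw input) is reshaped into per-head chunks before the einsum with k_weight. The following shape walk-through is only an illustrative sketch; the sizes and the relation d_head = d_model // n_head are assumptions, not taken from the commit.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    L, B, d_model, d_inner, n_head = 5, 3, 8, 16, 2
    d_head = d_model // n_head      # assumed relation, for illustration only

    q_net = nn.Linear(d_model, d_model)
    k_weight = torch.randn(n_head, d_inner, d_head)
    k_bias = torch.randn(n_head, d_inner)

    inp = torch.randn(L, B, d_model)
    head_q = q_net(inp).view(L, B, n_head, d_head)    # [L x B x n_head x d_head]
    attn_score = torch.einsum('ibnd,nhd->ibnh', head_q, k_weight) + k_bias
    attn_score = F.relu(attn_score)                   # [L x B x n_head x d_inner]
    print(attn_score.shape)  # torch.Size([5, 3, 2, 16])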
@@ -148,6 +193,47 @@ class MultiHeadPositionwiseFF(nn.Module):
         return output

+class PositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        super(PositionwiseFF, self).__init__()
+
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.CoreNet_1 = nn.Sequential(
+            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True)
+        )
+        self.CoreNet_2 = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(d_inner, d_model),
+            nn.Dropout(dropout),
+        )
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+
+    def forward(self, inp):
+        if self.pre_lnorm:
+            ##### layer normalization + positionwise feed-forward
+            relu_out = self.CoreNet_1(self.layer_norm(inp))
+            core_out = self.CoreNet_2(relu_out)
+
+            ##### residual connection
+            output = core_out + inp
+        else:
+            ##### positionwise feed-forward
+            relu_out = self.CoreNet_1(inp)
+            core_out = self.CoreNet_2(relu_out)
+
+            ##### residual connection + layer normalization
+            output = self.layer_norm(inp + core_out)
+
+        return output
+        # return output, relu_out.detach()
+
 class PositionwiseFF(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
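This hunk adds a plain PositionwiseFF block that applies layer normalization either before the feed-forward (pre_lnorm=True) or after the residual addition (pre_lnorm=False). A minimal usage sketch with illustrative sizes, assuming the class definition from this diff is in scope:

    import torch

    ff = PositionwiseFF(d_model=8, d_inner=16, dropout=0.1, pre_lnorm=False)
    x = torch.randn(5, 3, 8)    # [seq_len x batch x d_model]
    y = ff(x)                   # FF + residual, then LayerNorm (post-norm)
    print(y.shape)              # torch.Size([5, 3, 8])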