Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
823f9c2e
Commit
823f9c2e
authored
Nov 25, 2020
by
Jiezhong Qiu
Browse files
recover sparse ffn
parent
0f3e63eb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
36 deletions
+39
-36
pytorch/mem_transformer.py
pytorch/mem_transformer.py
+39
-36
No files found.
pytorch/mem_transformer.py
View file @
823f9c2e
...
@@ -7,7 +7,7 @@ import numpy as np
...
@@ -7,7 +7,7 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
#
import torch_sparse
import
torch_sparse
sys
.
path
.
append
(
'utils'
)
sys
.
path
.
append
(
'utils'
)
from
proj_adaptive_softmax
import
ProjectedAdaptiveLogSoftmax
from
proj_adaptive_softmax
import
ProjectedAdaptiveLogSoftmax
...
@@ -114,6 +114,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
...
@@ -114,6 +114,7 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
self
.
dropout_middle
=
nn
.
Dropout
(
dropout
*
ratio
)
self
.
dropout_middle
=
nn
.
Dropout
(
dropout
*
ratio
)
self
.
dropout_final
=
nn
.
Dropout
(
dropout
)
self
.
dropout_final
=
nn
.
Dropout
(
dropout
)
self
.
scale
=
1
/
(
d_model
**
0.5
)
self
.
reset_parameter
()
self
.
reset_parameter
()
def
reset_parameter
(
self
):
def
reset_parameter
(
self
):
...
@@ -131,6 +132,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
...
@@ -131,6 +132,8 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
block
=
self
.
block_net
(
inp
)
block
=
self
.
block_net
(
inp
)
block_val
,
block_idx
=
torch
.
topk
(
block
,
k
=
self
.
top_block
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)
# [.. x top_k]
block_val
,
block_idx
=
torch
.
topk
(
block
,
k
=
self
.
top_block
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)
# [.. x top_k]
block_val
.
mul_
(
self
.
scale
)
gate
=
F
.
softmax
(
block_val
,
dim
=-
1
)
gate
=
F
.
softmax
(
block_val
,
dim
=-
1
)
...
@@ -154,51 +157,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
...
@@ -154,51 +157,51 @@ class HierarchicalMoEPositionwiseFF(nn.Module):
return
output
return
output
#
class SparsePositionwiseFF(nn.Module):
class
SparsePositionwiseFF
(
nn
.
Module
):
#
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
):
#
super(SparsePositionwiseFF, self).__init__()
super
(
SparsePositionwiseFF
,
self
).
__init__
()
#
print("SparsePositionwiseFF")
print
(
"SparsePositionwiseFF"
)
#
self.d_model = d_model
self
.
d_model
=
d_model
#
self.d_inner = d_inner
self
.
d_inner
=
d_inner
#
self.dropout = dropout
self
.
dropout
=
dropout
#
self.CoreNet_1 = nn.Sequential(
self
.
CoreNet_1
=
nn
.
Sequential
(
#
nn.Linear(d_model, d_inner),
nn
.
Linear
(
d_model
,
d_inner
),
#
nn.ReLU(inplace=True),
nn
.
ReLU
(
inplace
=
True
),
#
nn.Dropout(dropout)
nn
.
Dropout
(
dropout
)
#
)
)
#
self.W2 = nn.Parameter(torch.Tensor(d_inner, d_model))
self
.
W2
=
nn
.
Parameter
(
torch
.
Tensor
(
d_inner
,
d_model
))
#
self.b2 = nn.Parameter(torch.Tensor(d_model))
self
.
b2
=
nn
.
Parameter
(
torch
.
Tensor
(
d_model
))
#
self.layer_norm = nn.LayerNorm(d_model)
self
.
layer_norm
=
nn
.
LayerNorm
(
d_model
)
#
self.pre_lnorm = pre_lnorm
self
.
pre_lnorm
=
pre_lnorm
#
self.dropout_final = nn.Dropout(dropout)
self
.
dropout_final
=
nn
.
Dropout
(
dropout
)
#
self.reset_parameter()
self
.
reset_parameter
()
#
def reset_parameter(self):
def
reset_parameter
(
self
):
#
temp_Linear = nn.Linear(self.d_inner, self.d_model)
temp_Linear
=
nn
.
Linear
(
self
.
d_inner
,
self
.
d_model
)
#
self.W2.data = temp_Linear.weight.data.transpose(0, 1)
self
.
W2
.
data
=
temp_Linear
.
weight
.
data
.
transpose
(
0
,
1
)
#
self.b2.data = temp_Linear.bias.data
self
.
b2
.
data
=
temp_Linear
.
bias
.
data
#
def forward(self, inp):
def
forward
(
self
,
inp
):
#
residual = inp
residual
=
inp
#
if self.pre_lnorm:
if
self
.
pre_lnorm
:
#
inp = self.layer_norm(inp)
inp
=
self
.
layer_norm
(
inp
)
#
relu_out = self.CoreNet_1(inp).view(-1, self.d_inner)
relu_out
=
self
.
CoreNet_1
(
inp
).
view
(
-
1
,
self
.
d_inner
)
#
sparse_relu_out = torch_sparse.SparseTensor.from_dense(relu_out)
sparse_relu_out
=
torch_sparse
.
SparseTensor
.
from_dense
(
relu_out
)
#
core_out = torch_sparse.matmul(sparse_relu_out, self.W2) + self.b2
core_out
=
torch_sparse
.
matmul
(
sparse_relu_out
,
self
.
W2
)
+
self
.
b2
#
core_out = core_out.view(inp.size(0), inp.size(1), self.d_model)
core_out
=
core_out
.
view
(
inp
.
size
(
0
),
inp
.
size
(
1
),
self
.
d_model
)
#
core_out = self.dropout_final(core_out)
core_out
=
self
.
dropout_final
(
core_out
)
#
output = core_out + residual
output
=
core_out
+
residual
#
if not self.pre_lnorm:
if
not
self
.
pre_lnorm
:
#
output = self.layer_norm(output)
output
=
self
.
layer_norm
(
output
)
#
return output
return
output
class
MultiHeadPositionwiseFF
(
nn
.
Module
):
class
MultiHeadPositionwiseFF
(
nn
.
Module
):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
,
n_head
=
2
):
def
__init__
(
self
,
d_model
,
d_inner
,
dropout
,
pre_lnorm
=
False
,
n_head
=
2
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment