OpenDAS / FastMoE · Commits
Commit 9c92be55, authored Feb 05, 2021 by Rick Ho

    fit fmoe in transformer-xl

Parent: 5e9bb2e9

Showing 2 changed files with 19 additions and 79 deletions (+19 -79)
examples/transformer-xl/mem_transformer.py   +17 -77
examples/transformer-xl/scripts/getdata.sh   +2 -2
examples/transformer-xl/mem_transformer.py
@@ -9,8 +9,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 # import torch_sparse
-from cuda.moe import MOELayer

 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -33,81 +31,8 @@ class PositionalEmbedding(nn.Module):
         else:
             return pos_emb[:,None,:]
-class CustomizedMoEPositionwiseFF(nn.Module):
-    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=64):
-        super(CustomizedMoEPositionwiseFF, self).__init__()
-        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
-        self.top_k = top_k
-        assert num_expert >= top_k
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.gate = nn.Linear(d_model, num_expert)
-        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model + 1, out_feat=d_inner)
-        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner + 1, out_feat=d_model)
-
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-        self.dropout = nn.Dropout(dropout)
-        self.reset_parameter()
-
-    def reset_parameter(self):
-        pass
-
-    def forward(self, inp):
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
-                largest=True, sorted=False)  # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1) # (BxL) x 1 x top_k
-        # gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
-        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-
-        #core_out = []
-        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
-        inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
-        x = self.moe1(inp, gate_top_k_idx)
-        x = self.dropout(F.relu(x))
-        x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
-        x = self.moe2(x, gate_top_k_idx)
-        x = self.dropout(x)  # (BxLxtop_k) x d_model
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        """
-        for i in range(self.top_k):
-            gate_idx = gate_top_k_idx[:, i].contiguous()
-            x = self.moe1(inp, gate_idx)
-            x = self.dropout(F.relu(x))
-            x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
-            x = self.moe2(x, gate_idx)
-            x = self.dropout(x) # (BxL) x d_model
-            core_out.append(x.unsqueeze(1)) # (BxL) x 1 x d_model
-        core_out = torch.cat(core_out, dim=1) # (BxL) x top_k x d_model
-        """
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
-
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-
-        return output
-# A baseline naive slow implementation
 class MoEPositionwiseFFRaw(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
         super(MoEPositionwiseFFRaw, self).__init__()
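
For context, the deleted CustomizedMoEPositionwiseFF above hand-rolled its mixture-of-experts routing: score each token with a linear gate, keep the top_k expert scores, softmax-normalize them, run the token through every selected expert, and mix the expert outputs back together with a batched matmul against the gate scores. The sketch below isolates that pattern in plain PyTorch; ordinary nn.Linear modules stand in for the removed cuda.moe MOELayer, and the function name topk_gate_combine, the per-expert loop, and the toy sizes are illustrative only, not code from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

def topk_gate_combine(inp, gate, experts, top_k=2):
    # inp: (B, L, d_model); gate: nn.Linear(d_model, num_expert); experts: list of modules
    B, L, d_model = inp.shape
    scores = gate(inp)                                                    # (B, L, num_expert)
    top_val, top_idx = torch.topk(scores, k=top_k, dim=-1, largest=True, sorted=False)
    gate_score = F.softmax(top_val.view(-1, top_k), dim=-1).unsqueeze(1)  # (B*L, 1, top_k)
    # Duplicate each token top_k times and dispatch every copy to its selected expert.
    flat = inp.view(-1, d_model).repeat_interleave(top_k, dim=0)          # (B*L*top_k, d_model)
    idx = top_idx.reshape(-1)                                             # (B*L*top_k,)
    out = torch.empty_like(flat)
    for e, expert in enumerate(experts):                                  # naive per-expert loop
        mask = idx == e
        if mask.any():
            out[mask] = expert(flat[mask])
    out = out.view(-1, top_k, d_model)                                    # (B*L, top_k, d_model)
    return torch.bmm(gate_score, out).view(B, L, d_model)                 # gate-weighted mix

# Toy sizes for a quick shape check (values are arbitrary):
gate = nn.Linear(32, 4)
experts = [nn.Linear(32, 32) for _ in range(4)]
y = topk_gate_combine(torch.randn(2, 5, 32), gate, experts, top_k=2)
assert y.shape == (2, 5, 32)
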
@@ -158,7 +83,7 @@ class MoEPositionwiseFFRaw(nn.Module):
         output = self.layer_norm(output)
         return output
-        # return output, relu_out.detach()

 def my_topk(x, k, inplace=True):
     y = x if inplace else x.clone()
@@ -891,6 +816,21 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
         return output

+from fmoe import FMoETransformerMLP
+class CustomizedMoEPositionwiseFF(FMoETransformerMLP):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
+        def activation(x):
+            return self.dropout(F.relu(x))
+        super().__init__(num_expert=8, d_model=d_model, d_hidden=d_inner,
+                pre_lnorm=pre_lnorm, activation=activation)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x, bias = super().forward(x)
+        return x + bias
+
 class DecoderLayer(nn.Module):
     def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
         super(DecoderLayer, self).__init__()
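The replacement CustomizedMoEPositionwiseFF added above delegates routing and expert computation to FastMoE's FMoETransformerMLP and keeps only the dropout module and the bias handling locally. A rough usage sketch follows, assuming the constructor arguments visible in the diff (num_expert, d_model, d_hidden, pre_lnorm, activation) and the two-value return that forward() unpacks; the tensor sizes are made up for illustration.

import torch

# Hypothetical smoke test for the FMoE-backed feed-forward block above.
# mem_transformer.py feeds (seq_len, batch, d_model) tensors; sizes here are arbitrary.
ff = CustomizedMoEPositionwiseFF(d_model=512, d_inner=2048, dropout=0.1, pre_lnorm=False)
x = torch.randn(16, 4, 512)
y = ff(x)                      # super().forward(x) output plus its bias term, per the diff
assert y.shape == x.shape
# Note: FastMoE's expert kernels are CUDA-based, so in practice ff and x would be
# moved to a GPU (ff.cuda(), x.cuda()) before the forward pass.
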
examples/transformer-xl/scripts/getdata.sh
 echo "=== Acquiring datasets ==="
 echo "---"

-mkdir -p data
-cd data
+mkdir -p ../data
+cd ../data

 if [[ ! -d 'wikitext-2' ]]; then
     echo "- Downloading WikiText-2 (WT2)"