OpenDAS / FastMoE

Commit 2338a26e
Authored Jan 03, 2021 by Jiezhong Qiu
Parent d036918c

    incorporate customized cuda moe into xfmr-xl
Showing 3 changed files with 82 additions and 11 deletions:

  pytorch/cuda/moe.py           +6   -6
  pytorch/mem_transformer.py    +74  -5
  pytorch/run_enwik8_base.sh    +2   -0
pytorch/cuda/moe.py
@@ -5,8 +5,6 @@ import torch
 import moe_cuda
 
-torch.manual_seed(42)
-torch.cuda.manual_seed(42)
 
 class MOEFunction(Function):
     @staticmethod
@@ -21,12 +19,12 @@ class MOEFunction(Function):
 
     @staticmethod
     def backward(ctx, grad_out):
-        print("grad_out", grad_out)
-        print("input", ctx.saved_tensors[0])
+        # print("grad_out", grad_out)
+        # print("input", ctx.saved_tensors[0])
         grad_inp, grad_weight = moe_cuda.backward(grad_out.contiguous(), *ctx.saved_tensors)
         out_feat, in_feat = grad_weight.size()[1:]
-        print("grad_weight_column_major", grad_weight.flatten())
+        # print("grad_weight_column_major", grad_weight.flatten())
         grad_weight_row_major = grad_weight.view(-1, in_feat, out_feat).transpose(-1, -2).contiguous().view(-1, out_feat, in_feat)
         return grad_inp, None, grad_weight_row_major
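The backward pass reshapes the weight gradient returned by moe_cuda.backward: the variable names suggest the kernel hands it back column-major, and the view / transpose / contiguous chain converts it to the row-major layout the rest of PyTorch expects. A minimal pure-PyTorch sketch of why that chain works (not part of the commit; the column-major buffer below is built by hand):

import torch

# A row-major (out_feat x in_feat) matrix stored column-major is bitwise the
# same buffer as its (in_feat x out_feat) transpose stored row-major, so
# reinterpreting the flat data as (in_feat, out_feat) and transposing the last
# two dims recovers the row-major weight gradient.
num_expert, out_feat, in_feat = 4, 3, 2
expected = torch.arange(num_expert * out_feat * in_feat, dtype=torch.float32)
expected = expected.view(num_expert, out_feat, in_feat)

# Simulate the column-major buffer a CUDA GEMM would hand back, expert by expert.
col_major = torch.cat([w.t().contiguous().flatten() for w in expected])
grad_weight = col_major.view(num_expert, out_feat, in_feat)   # raw shape, wrong layout

# The exact conversion used in MOEFunction.backward above.
grad_weight_row_major = (grad_weight.view(-1, in_feat, out_feat)
                         .transpose(-1, -2).contiguous()
                         .view(-1, out_feat, in_feat))
assert torch.equal(grad_weight_row_major, expected)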
@@ -47,7 +45,7 @@ class MOELayer(nn.Module):
             self.weight.data[i] = linear.weight.data
 
     def forward(self, inp, gate):
-        return MOEFunction.apply(inp, gate, self.weight)
+        return MOEFunction.apply(inp, gate.int(), self.weight)
 
 
 class MOELayer_raw(nn.Module):
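The only functional change in forward() is the gate.int() cast: torch.topk produces int64 indices, while the CUDA extension evidently expects int32. A hedged usage sketch (the keyword arguments mirror how mem_transformer.py constructs MOELayer later in this commit; it needs the moe_cuda extension built and a GPU):

import torch
from cuda.moe import MOELayer

layer = MOELayer(num_expert=4, in_feat=16, out_feat=32).cuda()
tokens = torch.randn(8, 16, device='cuda')           # 8 tokens, in_feat = 16
gate = torch.randint(0, 4, (8,), device='cuda')      # one int64 expert index per token
out = layer(tokens, gate)                             # forward() casts gate to int32
print(out.shape)                                      # torch.Size([8, 32])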
@@ -75,6 +73,8 @@ class MOELayer_raw(nn.Module):
 
 def test():
+    torch.manual_seed(42)
+    torch.cuda.manual_seed(42)
     batch_size = 4
     num_expert = 4
     in_feat = 2
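With this change test() seeds the RNGs itself instead of the module doing it at import time, so importing cuda.moe no longer reseeds the caller. For reference, a plain-PyTorch version of what the single-gate MoE forward is presumably computing, which a test like this can compare the kernel against (an assumption spelled out here, not code from the commit): each token is multiplied by the weight matrix of the expert its gate index selects.

import torch

def moe_forward_reference(inp, gate, weight):
    # inp:    (batch, in_feat)
    # gate:   (batch,) integer expert indices
    # weight: (num_expert, out_feat, in_feat), row-major as in MOELayer
    return torch.einsum('boi,bi->bo', weight[gate.long()], inp)

torch.manual_seed(42)
batch_size, num_expert, in_feat, out_feat = 4, 4, 2, 3
inp = torch.randn(batch_size, in_feat)
gate = torch.randint(0, num_expert, (batch_size,))
weight = torch.randn(num_expert, out_feat, in_feat)
print(moe_forward_reference(inp, gate, weight).shape)   # torch.Size([4, 3])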
pytorch/mem_transformer.py
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
+# import torch_sparse
+from cuda.moe import MOELayer
 
 sys.path.append('utils')
 from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
 from log_uniform_sampler import LogUniformSampler, sample_logits
@@ -31,9 +33,76 @@ class PositionalEmbedding(nn.Module):
         else:
             return pos_emb[:,None,:]
 
-class MoEPositionwiseFF(nn.Module):
+
+class CustomizedMoEPositionwiseFF(nn.Module):
+    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=2, num_expert=4):
+        super(CustomizedMoEPositionwiseFF, self).__init__()
+        print("CustomizedMoEPositionwiseFF num_expert=%d top_k=%d" % (num_expert, top_k))
+        self.top_k = top_k
+        assert num_expert >= top_k
+        self.d_model = d_model
+        self.d_inner = d_inner
+        self.dropout = dropout
+
+        self.gate = nn.Linear(d_model, d_inner)
+        self.moe1 = MOELayer(num_expert=num_expert, in_feat=d_model, out_feat=d_inner)
+        self.moe2 = MOELayer(num_expert=num_expert, in_feat=d_inner, out_feat=d_model)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.pre_lnorm = pre_lnorm
+        self.dropout = nn.Dropout(dropout)
+
+        self.reset_parameter()
+
+    def reset_parameter(self):
+        pass
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)  # (BxL) x 1 x top_k
+        gate_top_k_idx = gate_top_k_idx.view(-1, self.top_k)
+
+        core_out = []
+        inp = inp.view(-1, self.d_model)
+        # inp = F.pad(inp, pad=(0, 1), mode='constant', value=1.0)
+        for i in range(self.top_k):
+            print("top %d" % i)
+            gate_idx = gate_top_k_idx[:, i].contiguous()
+            print(inp.size(), gate_idx.size())
+            x = self.moe1(inp, gate_idx)
+            x = self.dropout(F.relu(x))
+            # x = F.pad(x, pad=(0, 1), mode='constant', value=1.0)
+            x = self.moe2(x, gate_idx)
+            x = self.dropout(x)  # (BxL) x d_model
+            output.append(x.unsqueeze(1))  # (BxL) x 1 x d_model
+
+        core_out = torch.cat(core_out, dim=1)       # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
+        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output
+
+
+class MoEPositionwiseFFRaw(nn.Module):
     def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, top_k=64):
-        super(MoEPositionwiseFF, self).__init__()
+        super(MoEPositionwiseFFRaw, self).__init__()
         print("MoEPositionwiseFF")
         self.top_k = top_k
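The new CustomizedMoEPositionwiseFF replaces the dense position-wise FFN with a top-k mixture: a linear gate scores each token, topk picks expert indices, the two MOELayer calls run the selected experts, and a softmax-weighted bmm mixes the top_k candidate outputs before the residual and LayerNorm. Below is a self-contained, hedged sketch of that gate-and-mix flow with ordinary nn.Linear experts standing in for the fused CUDA MOELayer; the gate in the sketch maps d_model to num_expert so that topk yields expert indices, and the per-expert outputs are collected in the same list that is later concatenated, both of which are assumptions of this illustration rather than a copy of the committed code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoEFFSketch(nn.Module):
    """Illustrative stand-in for CustomizedMoEPositionwiseFF (slow, loop-based experts)."""
    def __init__(self, d_model, d_inner, dropout, num_expert=4, top_k=2):
        super().__init__()
        self.top_k, self.d_model = top_k, d_model
        self.gate = nn.Linear(d_model, num_expert)
        self.fc1 = nn.ModuleList([nn.Linear(d_model, d_inner) for _ in range(num_expert)])
        self.fc2 = nn.ModuleList([nn.Linear(d_inner, d_model) for _ in range(num_expert)])
        self.drop = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp):                                     # inp: L x B x d_model
        residual = inp
        top_val, top_idx = torch.topk(self.gate(inp), k=self.top_k, dim=-1)
        gate_score = F.softmax(top_val.view(-1, self.top_k), dim=-1).unsqueeze(1)  # (LB) x 1 x top_k
        top_idx = top_idx.view(-1, self.top_k)                  # (LB) x top_k
        flat = inp.reshape(-1, self.d_model)                    # (LB) x d_model

        candidates = []
        for k in range(self.top_k):
            # Route every token through its k-th selected expert (fc2(relu(fc1(x)))).
            x = torch.stack([self.fc2[e](self.drop(F.relu(self.fc1[e](t))))
                             for t, e in zip(flat, top_idx[:, k].tolist())])
            candidates.append(self.drop(x).unsqueeze(1))        # (LB) x 1 x d_model
        core = torch.bmm(gate_score, torch.cat(candidates, dim=1))  # (LB) x 1 x d_model
        core = core.view(residual.size(0), residual.size(1), self.d_model)
        return self.layer_norm(core + residual)

ff = TopKMoEFFSketch(d_model=8, d_inner=16, dropout=0.1)
print(ff(torch.randn(5, 2, 8)).shape)                            # torch.Size([5, 2, 8])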
@@ -820,7 +889,7 @@ class DecoderLayer(nn.Module):
 
         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
         # self.dec_attn = ExtendedMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                   pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, dec_attn_mask=None, mems=None):
@@ -840,7 +909,7 @@ class RelLearnableDecoderLayer(nn.Module):
 
         self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                   pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
@@ -861,7 +930,7 @@ class RelPartialLearnableDecoderLayer(nn.Module):
 
         self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
-        self.pos_ff = MultiHeadHierarchicalMoEPositionwiseFF(d_model, d_inner, dropout,
+        self.pos_ff = CustomizedMoEPositionwiseFF(d_model, d_inner, dropout,
                                                   pre_lnorm=kwargs.get('pre_lnorm'))
 
     def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
pytorch/run_enwik8_base.sh

 #!/bin/bash
+export LD_LIBRARY_PATH=/home/jiezhong/miniconda3/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 
 if [[ $1 == 'train' ]]; then
     echo 'Run training...'
     python train.py \
 ...
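The script dispatches on its first positional argument (the 'train' branch shown above), so the usual invocation would be, for example, bash run_enwik8_base.sh train; the added export extends LD_LIBRARY_PATH with the conda and CUDA runtime library paths, presumably so that train.py can load the moe_cuda extension.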