OpenDAS / FastMoE · Commits

Commit 58e949cf authored Jan 26, 2021 by Rick Ho

    initial version to run with megatron

Parent: f866ed0f
Showing 2 changed files with 49 additions and 12 deletions:

  fmoe/megatron.py  +4  -3
  fmoe/moe.py       +45 -9
fmoe/megatron.py

@@ -5,8 +5,9 @@ from .moe import FFFN
 def create_moe_mlp(args):
     assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
     num_experts = args.num_experts // args.model_parallel_size
-    fmoe = FFFN(num_experts, in_feat=args.hidden_size,
-            hidden_feat=args.hidden_size * 4, out_feat=args.hidden_size,
-            world_size=args.model_parallel_size)
+    fmoe = FFFN(num_experts,
+            d_model=args.hidden_size,
+            d_hidden=args.hidden_size * 4,
+            world_size=args.model_parallel_size)
     return fmoe
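
The call site now matches the new FFFN signature: d_model and d_hidden replace the old in_feat/hidden_feat/out_feat trio, and the experts are split evenly across model-parallel ranks. A minimal sketch of how this helper might be driven, assuming the fmoe package is installed and built; the args fields are hypothetical stand-ins for what Megatron-LM's argument parser would normally supply:

    # Sketch only -- not part of this commit.
    from types import SimpleNamespace
    from fmoe.megatron import create_moe_mlp

    args = SimpleNamespace(
        num_experts=8,           # total experts; must be divisible by model_parallel_size
        model_parallel_size=2,   # each rank keeps 8 // 2 = 4 local experts
        hidden_size=1024,        # d_model; d_hidden becomes 4 * hidden_size
    )
    moe_mlp = create_moe_mlp(args)  # an FFFN holding 4 experts on this rank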
fmoe/moe.py

 import math
 from torch import nn
 import torch
+import torch.nn.functional as F
 from .moe_function import moe

@@ -27,20 +28,55 @@ class FMoE(nn.Module):
 class FFFN(nn.Module):
-    def __init__(self, num_expert=32, in_feat=1024, hidden_feat=4096,
-            out_feat=1024, world_size=None, activation=torch.nn.functional.gelu):
+    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
+            world_size=None, activation=torch.nn.functional.gelu,
+            top_k=2, pre_lnorm=False):
         super(FFFN, self).__init__()
-        self.htoh4 = FMoE(num_expert, in_feat, hidden_feat,
-                world_size=world_size)
+        self.d_model = d_model
+        self.d_hidden = d_hidden
+        self.world_size = world_size
         self.activation = activation
-        self.h4toh = FMoE(num_expert, hidden_feat, out_feat,
-                world_size=world_size)
+        self.top_k = top_k
+        self.pre_lnorm = pre_lnorm
+        self.htoh4 = FMoE(num_expert, d_model, d_hidden,
+                world_size=world_size)
+        self.h4toh = FMoE(num_expert, d_hidden, d_model,
+                world_size=world_size)
+        self.gate = nn.Linear(d_model, num_expert)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.bias = torch.nn.parameter.Parameter(torch.zeros(d_model, dtype=torch.float32))

-    def forward(self, inp, gate):
-        x = self.htoh4(inp)
+    def forward(self, inp):
+        # import pdb; pdb.set_trace()
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(gate, k=self.top_k, dim=-1,
+                largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+        # (BxL) x 1 x top_k
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
+        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
+        inp = inp.view(-1, self.d_model).repeat_interleave(repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
+        x = self.htoh4(inp, gate_top_k_idx)
         x = self.activation(x)
-        x = self.h4toh(x)
-        return x
+        x = self.h4toh(x, gate_top_k_idx)
+        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
+        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+        output = core_out + residual
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output, self.bias

 class BruteForceMoE(nn.Module):
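
The rewritten forward adds a learned top-k gate in front of the two FMoE layers: each token scores all experts, keeps the top_k scores, softmax-normalizes them, routes top_k copies of the token through the expert MLP, and recombines the copies with a batched matrix multiply before the residual add and layer norm. Below is a shape-only sketch of that gating path, using an ordinary nn.Linear pair as a stand-in for the routed FMoE layers (an assumption for illustration; the real FMoE dispatches by gate_top_k_idx, which the stand-in ignores):

    # Sketch only -- mirrors the tensor shapes in the new FFFN.forward.
    import torch
    import torch.nn.functional as F
    from torch import nn

    B, L, d_model, num_expert, top_k = 2, 3, 8, 4, 2
    inp = torch.randn(B, L, d_model)
    gate_lin = nn.Linear(d_model, num_expert)
    htoh4 = nn.Linear(d_model, 4 * d_model)   # stand-in for FMoE(num_expert, d_model, d_hidden)
    h4toh = nn.Linear(4 * d_model, d_model)   # stand-in for FMoE(num_expert, d_hidden, d_model)

    residual = inp
    gate = gate_lin(inp)                                               # B x L x num_expert
    val, idx = torch.topk(gate, k=top_k, dim=-1, sorted=False)
    gate_score = F.softmax(val.view(-1, top_k), dim=-1).unsqueeze(1)   # (BxL) x 1 x top_k
    idx = idx.view(-1)                                                 # (BxLxtop_k), used by FMoE for routing

    x = inp.view(-1, d_model).repeat_interleave(repeats=top_k, dim=0)  # (BxLxtop_k) x d_model
    x = h4toh(F.gelu(htoh4(x)))                                        # expert MLP on each routed copy
    x = x.view(-1, top_k, d_model)                                     # (BxL) x top_k x d_model
    out = torch.bmm(gate_score, x).view(B, L, d_model)                 # weighted sum over the k experts
    out = out + residual                                               # residual add; post-norm case applies layer_norm here
    print(out.shape)                                                   # torch.Size([2, 3, 8])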