OpenDAS / FastMoE · Commits

Commit 437afda2, authored Jan 29, 2021 by Rick Ho
reconstruct fmoe nn module
Parent: 5e0af68d

Showing 4 changed files with 158 additions and 55 deletions.
fmoe/__init__.py    +2    -1
fmoe/fmoe.py        +100  -0
fmoe/megatron.py    +56   -2
fmoe/moe.py         +0    -52
fmoe/__init__.py
-from .moe import FMoE, BruteForceMoE
+from .moe import BruteForceMoE
+from .fmoe import FMoELinear, FMoENaiveGate, FMoETransformerMLP
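After this change the package root re-exports the new classes from fmoe.fmoe instead of fmoe.moe. A minimal import sketch, assuming the package is importable as fmoe:

    # Illustrative only: what downstream code can import from the package root after this commit.
    from fmoe import BruteForceMoE, FMoELinear, FMoENaiveGate, FMoETransformerMLP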
fmoe/fmoe.py (new file, mode 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F

from .fmoe_functions import *


class FMoELinear(nn.Module):
    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super(FMoELinear, self).__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # One weight matrix per expert: (num_expert, out_feat, in_feat).
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize each expert's weight the same way nn.Linear would.
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.out_feat)
            self.weight.data[i] = linear.weight.data

    def forward(self, inp, fwd_expert_count):
        # inp is expert-major (rows grouped by expert); fwd_expert_count holds
        # the number of rows assigned to each expert.
        return MOELinear.apply(inp, self.weight, fwd_expert_count)
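For intuition, a plain-PyTorch sketch of what MOELinear.apply is expected to compute here, reusing the torch / F imports above and assuming inp is already grouped expert-by-expert with fwd_expert_count[i] counting expert i's rows (an illustration, not the extension's implementation):

    def reference_moe_linear(inp, weight, fwd_expert_count):
        # Loop-based equivalent of the fused per-expert matmul.
        outputs, start = [], 0
        for i, count in enumerate(fwd_expert_count.tolist()):
            chunk = inp[start:start + count]
            outputs.append(F.linear(chunk, weight[i]))  # (count, out_feat)
            start += count
        return torch.cat(outputs, dim=0)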
class FMoENaiveGate(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, world_size=1, top_k=2):
        super(FMoENaiveGate, self).__init__()
        # d_model and self.top_k were missing in the original; both are needed below.
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_expert * world_size)

    def forward(self, inp):
        gate = self.gate(inp)
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)

        # (BxL) x 1 x top_k
        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)

        return gate_top_k_idx, gate_score
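A toy run of the gate, just to pin down the shapes it returns (random values, illustrative only):

    gate_module = FMoENaiveGate(num_expert=4, d_model=16, world_size=1, top_k=2)
    tokens = torch.randn(8, 16)        # 8 tokens, d_model = 16
    idx, score = gate_module(tokens)
    # idx:   shape (16,)  = 8 tokens x top_k flat expert indices
    # score: shape (8, 1, 2) softmax weights, later consumed by torch.bmm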
def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    (pos, local_expert_count, global_expert_count,
     fwd_expert_count, fwd_batch_size) = moe_prepare_forward(
        gate, num_expert, world_size)
    # Scatter rows to their experts (possibly on other workers).
    x = MOEScatter.apply(inp, pos,
                         local_expert_count, global_expert_count,
                         fwd_batch_size, world_size)
    # Run the expert MLP: activation between consecutive FMoELinear layers.
    for i, l in enumerate(linears):
        if i:
            x = activation(x)
        x = l(x, fwd_expert_count)  # FMoELinear.forward needs fwd_expert_count
    # Gather the results back into the original row order.
    x = MOEGather.apply(x, pos,
                        local_expert_count, global_expert_count,
                        inp.shape[0], world_size)
    return x
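As a mental model of this scatter → expert MLP → gather pipeline, a single-process, pure-PyTorch sketch (a hypothetical helper, not the CUDA path; it reuses the torch / F imports above and assumes gate holds one expert index per row of inp and world_size == 1):

    def reference_full_forward(inp, gate, weights, activation, num_expert):
        # weights: list of (num_expert, out_feat, in_feat) tensors,
        # e.g. [htoh4.weight, h4toh.weight]; gate: (n_rows,) expert index per row.
        order = torch.argsort(gate)                  # expert-major ordering ("scatter")
        counts = torch.bincount(gate, minlength=num_expert).tolist()
        x = inp[order]
        outs, start = [], 0
        for e, c in enumerate(counts):
            chunk = x[start:start + c]
            for i, w in enumerate(weights):
                if i:
                    chunk = activation(chunk)
                chunk = F.linear(chunk, w[e])        # expert e's slice of each FMoELinear
            outs.append(chunk)
            start += c
        y = torch.cat(outs, dim=0)
        out = torch.empty_like(y)
        out[order] = y                               # undo the ordering ("gather")
        return out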
class FMoETransformerMLP(nn.Module):
    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
                 world_size=None,
                 activation=torch.nn.functional.gelu,
                 top_k=2, pre_lnorm=False):
        super(FMoETransformerMLP, self).__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.activation = activation
        self.top_k = top_k
        self.pre_lnorm = pre_lnorm
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
        # The gate needs d_model as well (see FMoENaiveGate above).
        self.gate = FMoENaiveGate(num_expert=num_expert, d_model=d_model,
                                  world_size=world_size, top_k=top_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(d_model, dtype=torch.float32))

    def forward(self, inp):
        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        # Gate on the un-replicated tokens, then replicate each token top_k times
        # for its top_k experts (same ordering as FFFN.forward in fmoe/megatron.py;
        # the original gated the already-replicated input, which breaks the shapes
        # consumed by the bmm below).
        inp = inp.view(-1, self.d_model)                          # (BxL) x d_model
        gate_top_k_idx, gate_score = self.gate(inp)
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)    # (BxLxtop_k) x d_model

        x = _fmoe_full_forward(inp, gate_top_k_idx,
                               [self.htoh4, self.h4toh],
                               self.activation,
                               self.num_expert, self.world_size)

        core_out = x.view(-1, self.top_k, self.d_model)   # (BxL) x top_k x d_model
        core_out = torch.bmm(gate_score, core_out)        # (BxL) x 1 x d_model
        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
        output = core_out + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)
        return output, self.bias
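A sketch of how the new layer might be driven end to end, assuming the CUDA extension behind fmoe_functions is built and accepts world_size=1 for single-process use (the shapes and the world_size choice are illustrative, not asserted by the commit):

    import torch
    from fmoe import FMoETransformerMLP

    mlp = FMoETransformerMLP(num_expert=4, d_model=64, d_hidden=256,
                             world_size=1, top_k=2).cuda()
    x = torch.randn(8, 16, 64, device='cuda')   # batch x seq_len x d_model
    out, bias = mlp(x)                           # returns (output, bias), the
                                                 # Megatron-style MLP contract
    print(out.shape)                             # torch.Size([8, 16, 64])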
fmoe/megatron.py
 from torch import nn
-from .moe import FFFN
+import torch
+import torch.nn.functional as F  # torch and F are used by FFFN below but were not imported in the original hunk
+from .moe import FMoE
+from .moe_function import moe
+from .fmoe import FMoETransformerMLP
+class FFFN(nn.Module):
+    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
+                 world_size=None,
+                 activation=torch.nn.functional.gelu,
+                 top_k=2, pre_lnorm=False):
+        super(FFFN, self).__init__()
+        self.d_model = d_model
+        self.d_hidden = d_hidden
+        self.world_size = world_size
+        self.activation = activation
+        self.top_k = top_k
+        self.pre_lnorm = pre_lnorm
+        self.htoh4 = FMoE(num_expert, d_model, d_hidden, world_size=world_size)
+        self.h4toh = FMoE(num_expert, d_hidden, d_model, world_size=world_size)
+        self.gate = nn.Linear(d_model, num_expert * world_size)
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.bias = torch.nn.parameter.Parameter(
+            torch.zeros(d_model, dtype=torch.float32))
+
+    def forward(self, inp):
+        residual = inp
+        if self.pre_lnorm:
+            inp = self.layer_norm(inp)
+
+        gate = self.gate(inp)
+        gate_top_k_val, gate_top_k_idx = torch.topk(
+            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
+        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
+
+        # (BxL) x 1 x top_k
+        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
+        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
+
+        inp = inp.view(-1, self.d_model).repeat_interleave(
+            repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
+
+        x = self.htoh4(inp, gate_top_k_idx)
+        x = self.activation(x)
+        x = self.h4toh(x, gate_top_k_idx)
+
+        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
+        core_out = torch.bmm(gate_score, core_out)       # (BxL) x 1 x d_model
+        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+        output = core_out + residual
+
+        if not self.pre_lnorm:
+            output = self.layer_norm(output)
+        return output, self.bias
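The per-token combine above is a batched matmul over shapes like these (toy numbers, not from the commit; reuses the torch import above):

    gate_score = torch.rand(6, 1, 2).softmax(dim=-1)   # (BxL, 1, top_k) gate weights
    core_out = torch.randn(6, 2, 64)                   # (BxL, top_k, d_model) expert outputs
    combined = torch.bmm(gate_score, core_out)         # (6, 1, 64): weighted sum over the top_k experts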
 def create_moe_mlp(args):
     assert args.num_experts % args.model_parallel_size == 0, 'Num experts should be multiple of mp size'
     num_experts = args.num_experts // args.model_parallel_size
-    fmoe = FFFN(num_experts,
+    fmoe = FMoETransformerMLP(num_experts,
                 d_model=args.hidden_size,
                 d_hidden=args.hidden_size * 4,
                 world_size=args.model_parallel_size)
...
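For context, a sketch of the argument object create_moe_mlp expects, using only the fields visible in this hunk (the Namespace is illustrative; in Megatron-LM these values come from its own argument parser):

    from argparse import Namespace
    from fmoe.megatron import create_moe_mlp

    args = Namespace(num_experts=8,          # total experts, split across model-parallel ranks
                     model_parallel_size=2,
                     hidden_size=1024)
    mlp = create_moe_mlp(args)               # FMoETransformerMLP with 4 experts per rank,
                                             # d_model=1024, d_hidden=4096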
fmoe/moe.py
...
@@ -27,58 +27,6 @@ class FMoE(nn.Module):
         return moe(inp, gate.int(), self.weight, self.world_size)
-class FFFN(nn.Module):
-    def __init__(self, num_expert=32, d_model=1024, d_hidden=4096,
-                 world_size=None,
-                 activation=torch.nn.functional.gelu,
-                 top_k=2, pre_lnorm=False):
-        super(FFFN, self).__init__()
-        self.d_model = d_model
-        self.d_hidden = d_hidden
-        self.world_size = world_size
-        self.activation = activation
-        self.top_k = top_k
-        self.pre_lnorm = pre_lnorm
-        self.htoh4 = FMoE(num_expert, d_model, d_hidden, world_size=world_size)
-        self.h4toh = FMoE(num_expert, d_hidden, d_model, world_size=world_size)
-        self.gate = nn.Linear(d_model, num_expert * world_size)
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.bias = torch.nn.parameter.Parameter(
-            torch.zeros(d_model, dtype=torch.float32))
-
-    def forward(self, inp):
-        # import pdb; pdb.set_trace()
-        residual = inp
-        if self.pre_lnorm:
-            inp = self.layer_norm(inp)
-        gate = self.gate(inp)
-        gate_top_k_val, gate_top_k_idx = torch.topk(
-            gate, k=self.top_k, dim=-1, largest=True, sorted=False)  # [.. x top_k]
-        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)
-        # (BxL) x 1 x top_k
-        gate_score = F.softmax(gate_top_k_val, dim=-1).unsqueeze(1)
-        gate_top_k_idx = gate_top_k_idx.view(-1)  # (BxLxtop_k)
-        inp = inp.view(-1, self.d_model).repeat_interleave(
-            repeats=self.top_k, dim=0)  # (BxLxtop_k) x d_model
-        x = self.htoh4(inp, gate_top_k_idx)
-        x = self.activation(x)
-        x = self.h4toh(x, gate_top_k_idx)
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
-        output = core_out + residual
-        if not self.pre_lnorm:
-            output = self.layer_norm(output)
-        return output, self.bias
 class BruteForceMoE(nn.Module):
     def __init__(self, num_expert=32, in_feat=1024, out_feat=1024,
                  world_size=0):
...