OpenDAS / FastMoE · Commits
2d81858b
Commit
2d81858b
authored
Nov 23, 2021
by
Jiezhong Qiu
Browse files
allow netsd data struct as moe input
parent
b652e8d8
Showing 2 changed files with 97 additions and 38 deletions:

  fmoe/layers.py    +96 -38
  requirements.txt   +1  -0
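The change routes every tensor operation in the MoE pipeline through `tree.map_structure` from the newly added dm-tree dependency. As a minimal sketch of that primitive (illustrative, not from the commit): it applies a function to every leaf tensor of an arbitrarily nested dict/list/tuple and returns a structure of the same shape.

    import torch
    import tree  # provided by the dm-tree package

    nested = {"tokens": torch.randn(4, 8), "extras": [torch.randn(4, 2)]}

    # Leaves are visited one by one; the nesting (dict of list, etc.) is preserved.
    doubled = tree.map_structure(lambda t: t * 2, nested)
    assert doubled["tokens"].shape == (4, 8)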
fmoe/layers.py
@@ ... @@
 r"""
 FMoE core layer
 """
+import tree
 import torch
 import torch.nn as nn
@@ -10,7 +11,6 @@ from .functions import AllGather, Slice
 from .gates import NaiveGate


 def mark_module_parallel_comm(module, comm):
     r"""
     Mark all parameters in `module` as doing data parallel in `comm`, where
...
@@ -42,22 +42,39 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
topk
=
1
if
len
(
gate
.
shape
)
==
2
:
topk
=
gate
.
shape
[
1
]
x
=
MOEScatter
.
apply
(
inp
,
pos
//
topk
,
local_expert_count
,
global_expert_count
,
fwd_batch_size
,
world_size
)
def
scatter_func
(
inp_tensor
):
tensor
=
MOEScatter
.
apply
(
inp_tensor
,
pos
//
topk
,
local_expert_count
,
global_expert_count
,
fwd_batch_size
,
world_size
,
)
return
tensor
x
=
tree
.
map_structure
(
scatter_func
,
inp
)
x
=
expert_fn
(
x
,
fwd_expert_count
)
out_batch_size
=
inp
.
shape
[
0
]
if
len
(
gate
.
shape
)
==
2
:
out_batch_size
*=
gate
.
shape
[
1
]
x
=
MOEGather
.
apply
(
x
,
pos
,
local_expert_count
,
global_expert_count
,
out_batch_size
,
world_size
)
return
x
def
gatter_func
(
outp_tensor
):
tensor
=
MOEGather
.
apply
(
outp_tensor
,
pos
,
local_expert_count
,
global_expert_count
,
out_batch_size
,
world_size
,
)
return
tensor
outp
=
tree
.
map_structure
(
gatter_func
,
x
)
return
outp
class
FMoE
(
nn
.
Module
):
...
...
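With scatter and gather wrapped in `scatter_func`/`gatter_func` and mapped over the input, every leaf of a nested input is reordered by the same routing plan (`pos`, computed once from the gate). A hedged, pure-PyTorch stand-in for that per-leaf behavior (toy `pos`; the real work happens inside the MOEScatter/MOEGather autograd functions, possibly across workers):

    import torch
    import tree

    pos = torch.tensor([2, 0, 1, 3])  # toy routing permutation from the gate
    inp = {"h": torch.randn(4, 16), "aux": torch.randn(4, 3)}

    # Every leaf must share the batch dimension the gate was computed from,
    # because the same indices are applied to each of them.
    scattered = tree.map_structure(lambda t: t[pos], inp)
    assert scattered["h"].shape == (4, 16)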
@@ -84,7 +101,7 @@ class FMoE(nn.Module):
         num_expert=32,
         d_model=1024,
         world_size=1,
-        mp_group=None,  # being deprecated
+        mp_group=None,  # being deprecated
         slice_group=None,
         moe_group=None,
         top_k=2,
@@ -101,7 +118,7 @@ class FMoE(nn.Module):
         self.slice_group = slice_group
         if mp_group is not None:
-            print('[Warning] mp_group is being deprecated')
+            print("[Warning] mp_group is being deprecated")
             self.slice_group = mp_group
         if self.slice_group is None:
             self.slice_size = 1
@@ -116,8 +133,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
             self.num_expert = num_expert = len(expert)
         elif expert is not None:
-            self.experts = nn.ModuleList(
-                [expert(d_model) for _ in range(num_expert)])
+            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True
@@ -159,52 +175,94 @@ class FMoE(nn.Module):
         mark_module_parallel_comm(self.experts, comm)
         mark_module_parallel_comm(self.gate, "gate")

-    def forward(self, inp):
+    def forward(self, moe_inp, non_moe_inp=None):
         r"""
         The FMoE module first computes gate output, and then conduct MoE forward
         according to the gate. The score of the selected gate given by the
         expert is multiplied to the experts' output tensors as a weight.
         """
         if self.world_size > 1:
-            ensure_comm(inp, self.moe_group)
+
+            def ensure_comm_func(tensor):
+                ensure_comm(tensor, self.moe_group)
+
+            tree.map_structure(ensure_comm_func, moe_inp)
         if self.slice_size > 1:
-            inp = Slice.apply(
-                inp, self.slice_rank, self.slice_size, self.slice_group
-            )
-        gate_top_k_idx, gate_score = self.gate(inp)
+
+            def slice_func(tensor):
+                return Slice.apply(
+                    tensor, self.slice_rank, self.slice_size, self.slice_group
+                )
+
+            moe_inp = tree.map_structure(slice_func, moe_inp)
+
+        gate_top_k_idx, gate_score = self.gate(moe_inp)

         if self.gate_hook is not None:
             self.gate_hook(gate_top_k_idx, gate_score, None)

+        # TODO: to fix
+        def delete_mask_func(tensor):
+            # to: (BxL') x d_model
+            tensor = tensor[mask == 0, :]
+            return tensor
+
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
             mask = self.mask.view(-1)
-            # to: (BxL') x d_model
-            inp = inp[mask == 0, :]
+            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]

         fwd = _fmoe_general_global_forward(
-            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+            moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
         )

         # recover deleted tensors
         if self.mask is not None and self.mask_dict is not None:
-            # to: (BxL') x top_k x d_model
-            fwd = fwd.view(-1, self.top_k, self.d_model)
-            # to: (BxL) x top_k x d_model
-            x = torch.zeros(
-                mask.shape[0], self.top_k, self.d_model,
-                device=fwd.device, dtype=fwd.dtype
-            )
-            # recover
-            x[mask == 0] = fwd
-            for k, v in self.mask_dict.items():
-                x[mask == k] = v
+
+            def recover_func(tensor):
+                # to: (BxL') x top_k x dim
+                dim = tensor.shape[-1]
+                tensor = tensor.view(-1, self.top_k, dim)
+                # to: (BxL) x top_k x d_model
+                x = torch.zeros(
+                    mask.shape[0],
+                    self.top_k,
+                    dim,
+                    device=tensor.device,
+                    dtype=tensor.dtype,
+                )
+                # recover
+                x[mask == 0] = tensor
+                for k, v in self.mask_dict.items():
+                    x[mask == k] = v
+                return x
+
+            moe_outp = tree.map_structure(recover_func, fwd)
         else:
-            x = fwd.view(-1, self.top_k, self.d_model)
-
-        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
-        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
+
+            def view_func(tensor):
+                dim = tensor.shape[-1]
+                tensor = tensor.view(-1, self.top_k, dim)
+                return tensor
+
+            moe_outp = tree.map_structure(view_func, fwd)
+
+        gate_score = gate_score.view(-1, 1, self.top_k)
+
+        def bmm_func(tensor):
+            dim = tensor.shape[-1]
+            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
+            return tensor
+
+        moe_outp = tree.map_structure(bmm_func, moe_outp)

         if self.slice_size > 1:
-            x = AllGather.apply(
-                x, self.slice_rank, self.slice_size, self.slice_group
-            )
-        return x
+
+            def all_gather_func(tensor):
+                return AllGather.apply(
+                    tensor, self.slice_rank, self.slice_size, self.slice_group
+                )
+
+            moe_outp = tree.map_structure(all_gather_func, moe_outp)
+
+        return moe_outp
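The weighting step works the same way leaf by leaf. A standalone sketch (toy shapes, not the FastMoE API) of the new `view_func`/`bmm_func` pair: expert outputs arrive as (tokens x top_k, dim) per leaf, are reshaped to (tokens, top_k, dim), and combined with the gate scores by a batched matmul:

    import torch
    import tree

    top_k, n_tok = 2, 32
    gate_score = torch.softmax(torch.randn(n_tok, top_k), dim=-1).view(-1, 1, top_k)

    # Each leaf holds top_k expert outputs per token; dims may differ per leaf.
    fwd = {"h": torch.randn(n_tok * top_k, 16), "aux": torch.randn(n_tok * top_k, 8)}

    def view_func(tensor):
        return tensor.view(-1, top_k, tensor.shape[-1])  # (n_tok, top_k, dim)

    def bmm_func(tensor):
        # (n_tok, 1, top_k) @ (n_tok, top_k, dim) -> (n_tok, dim)
        return torch.bmm(gate_score, tensor).reshape(-1, tensor.shape[-1])

    moe_outp = tree.map_structure(bmm_func, tree.map_structure(view_func, fwd))
    assert moe_outp["h"].shape == (n_tok, 16) and moe_outp["aux"].shape == (n_tok, 8)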
requirements.txt
@@ -1,3 +1,4 @@
 torch
 numpy
 ninja
+dm-tree
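The `tree` module imported at the top of fmoe/layers.py is provided by this dm-tree package (PyPI name dm-tree, import name tree):

    pip install dm-tree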