OpenDAS / FastMoE · Commit 74b6908f (unverified)

Authored Nov 23, 2021 by Rick Ho; committed by GitHub, Nov 23, 2021

Merge pull request #90 from xptree/multi_input

MoE with multiple inputs and multiple outputs

Parents: b652e8d8, 0abea7b2
Showing 3 changed files with 277 additions and 39 deletions:

    fmoe/layers.py       +110  -39
    requirements.txt       +1   -0
    tests/test_mimo.py   +166   -0
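In short, FMoE.forward and the underlying scatter/expert/gather pipeline now accept a nested structure (a dict or list) of tensors instead of a single tensor, apply every step leaf-wise via dm-tree, and return an output with the same structure, e.g. moe({"x": x, "y": y}) -> {"x": ..., "y": ...}. The experts and the gate are expected to understand the chosen structure themselves, as exercised by the new tests/test_mimo.py below.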
fmoe/layers.py (view file @ 74b6908f)
 r"""
 FMoE core layer
 """
+import tree
 import torch
 import torch.nn as nn
 ...
@@ -10,7 +11,6 @@ from .functions import AllGather, Slice
 from .gates import NaiveGate
-
 
 
 def mark_module_parallel_comm(module, comm):
     r"""
     Mark all parameters in `module` as doing data parallel in `comm`, where
 ...
@@ -42,22 +42,37 @@ def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
     topk = 1
     if len(gate.shape) == 2:
         topk = gate.shape[1]
-    x = MOEScatter.apply(
-        inp, pos // topk,
-        local_expert_count, global_expert_count, fwd_batch_size, world_size
-    )
+
+    def scatter_func(tensor):
+        return MOEScatter.apply(
+            tensor,
+            pos // topk,
+            local_expert_count,
+            global_expert_count,
+            fwd_batch_size,
+            world_size,
+        )
+
+    x = tree.map_structure(scatter_func, inp)
+
     x = expert_fn(x, fwd_expert_count)
-    out_batch_size = inp.shape[0]
+
+    out_batch_size = tree.flatten(inp)[0].shape[0]
     if len(gate.shape) == 2:
         out_batch_size *= gate.shape[1]
-    x = MOEGather.apply(
-        x, pos,
-        local_expert_count, global_expert_count,
-        out_batch_size, world_size
-    )
-    return x
+
+    def gather_func(tensor):
+        return MOEGather.apply(
+            tensor,
+            pos,
+            local_expert_count,
+            global_expert_count,
+            out_batch_size,
+            world_size,
+        )
+
+    outp = tree.map_structure(gather_func, x)
+    return outp
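For readers unfamiliar with dm-tree: tree.map_structure applies a function to every leaf of a nested structure (tensor, list, dict, ...) and returns a structure of the same shape, while tree.flatten returns the leaves as a flat list. A minimal, self-contained sketch (the values are made up for illustration, not taken from the diff):

    import tree  # pip install dm-tree

    nested = {"x": [1, 2], "y": 3}
    doubled = tree.map_structure(lambda v: v * 2, nested)  # {"x": [2, 4], "y": 6}
    leaves = tree.flatten(nested)                          # [1, 2, 3]
    print(doubled, leaves)

A single tensor is itself a valid (single-leaf) structure, which is why the rewritten scatter/gather path keeps working for the old single-input case.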
 class FMoE(nn.Module):
 ...
@@ -84,7 +99,7 @@ class FMoE(nn.Module):
         num_expert=32,
         d_model=1024,
         world_size=1,
-        mp_group=None, # being deprecated
+        mp_group=None,  # being deprecated
         slice_group=None,
         moe_group=None,
         top_k=2,
 ...
@@ -101,7 +116,7 @@ class FMoE(nn.Module):
         self.slice_group = slice_group
         if mp_group is not None:
-            print('[Warning] mp_group is being deprecated')
+            print("[Warning] mp_group is being deprecated")
             self.slice_group = mp_group
         if self.slice_group is None:
             self.slice_size = 1
 ...
@@ -116,8 +131,7 @@ class FMoE(nn.Module):
             self.experts_fused = False
             self.num_expert = num_expert = len(expert)
         elif expert is not None:
-            self.experts = nn.ModuleList([expert(d_model)
-                for _ in range(num_expert)])
+            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
             self.experts_fused = False
         else:
             self.experts_fused = True
 ...
@@ -159,52 +173,109 @@ class FMoE(nn.Module):
             mark_module_parallel_comm(self.experts, comm)
         mark_module_parallel_comm(self.gate, "gate")
 
-    def forward(self, inp):
+    def forward(self, moe_inp):
         r"""
         The FMoE module first computes the gate output, and then conducts the MoE
         forward according to the gate. The score of the selected gate given by the
         expert is multiplied with the experts' output tensors as a weight.
         """
+        moe_inp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
+        )
+        assert all(
+            [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
+        ), "MoE inputs must have the same batch size"
+
         if self.world_size > 1:
-            ensure_comm(inp, self.moe_group)
+
+            def ensure_comm_func(tensor):
+                ensure_comm(tensor, self.moe_group)
+
+            tree.map_structure(ensure_comm_func, moe_inp)
+
         if self.slice_size > 1:
-            inp = Slice.apply(inp, self.slice_rank, self.slice_size, self.slice_group)
-        gate_top_k_idx, gate_score = self.gate(inp)
+
+            def slice_func(tensor):
+                return Slice.apply(
+                    tensor, self.slice_rank, self.slice_size, self.slice_group
+                )
+
+            moe_inp = tree.map_structure(slice_func, moe_inp)
+
+        gate_top_k_idx, gate_score = self.gate(moe_inp)
+
         if self.gate_hook is not None:
             self.gate_hook(gate_top_k_idx, gate_score, None)
 
         # delete masked tensors
         if self.mask is not None and self.mask_dict is not None:
+            # TODO: to fix
+            def delete_mask_func(tensor):
+                # to: (BxL') x d_model
+                tensor = tensor[mask == 0, :]
+                return tensor
+
             mask = self.mask.view(-1)
-            # to: (BxL') x d_model
-            inp = inp[mask == 0, :]
+            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
             gate_top_k_idx = gate_top_k_idx[mask == 0, :]
 
         fwd = _fmoe_general_global_forward(
-            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
+            moe_inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
         )
 
         # recover deleted tensors
         if self.mask is not None and self.mask_dict is not None:
-            # to: (BxL') x top_k x d_model
-            fwd = fwd.view(-1, self.top_k, self.d_model)
-            # to: (BxL) x top_k x d_model
-            x = torch.zeros(
-                mask.shape[0], self.top_k, self.d_model,
-                device=fwd.device, dtype=fwd.dtype
-            )
-            # recover
-            x[mask == 0] = fwd
-            for k, v in self.mask_dict.items():
-                x[mask == k] = v
+
+            def recover_func(tensor):
+                # to: (BxL') x top_k x dim
+                dim = tensor.shape[-1]
+                tensor = tensor.view(-1, self.top_k, dim)
+                # to: (BxL) x top_k x d_model
+                x = torch.zeros(
+                    mask.shape[0],
+                    self.top_k,
+                    dim,
+                    device=tensor.device,
+                    dtype=tensor.dtype,
+                )
+                # recover
+                x[mask == 0] = tensor
+                for k, v in self.mask_dict.items():
+                    x[mask == k] = v
+                return x
+
+            moe_outp = tree.map_structure(recover_func, fwd)
         else:
-            x = fwd.view(-1, self.top_k, self.d_model)
 
-        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
-        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
+            def view_func(tensor):
+                dim = tensor.shape[-1]
+                tensor = tensor.view(-1, self.top_k, dim)
+                return tensor
+
+            moe_outp = tree.map_structure(view_func, fwd)
+
+        gate_score = gate_score.view(-1, 1, self.top_k)
+
+        def bmm_func(tensor):
+            dim = tensor.shape[-1]
+            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
+            return tensor
+
+        moe_outp = tree.map_structure(bmm_func, moe_outp)
 
         if self.slice_size > 1:
-            x = AllGather.apply(x, self.slice_rank, self.slice_size, self.slice_group)
-        return x
+
+            def all_gather_func(tensor):
+                return AllGather.apply(
+                    tensor, self.slice_rank, self.slice_size, self.slice_group
+                )
+
+            moe_outp = tree.map_structure(all_gather_func, moe_outp)
+
+        moe_outp_batch_size = tree.flatten(
+            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
+        )
+        assert all(
+            [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
+        ), "MoE outputs must have the same batch size"
+        return moe_outp
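The score-weighting step above (view_func followed by bmm_func) combines each token's top_k expert outputs using the gate scores via a batched matrix multiply. A small, self-contained sketch of that arithmetic, with made-up shapes and values rather than anything taken from the diff:

    import torch

    batch, top_k, dim = 4, 2, 8
    expert_outp = torch.randn(batch * top_k, dim)           # one row per (token, chosen expert)
    gate_score = torch.softmax(torch.randn(batch, top_k), dim=-1)

    x = expert_outp.view(-1, top_k, dim)                    # (batch, top_k, dim)
    combined = torch.bmm(gate_score.view(-1, 1, top_k), x)  # (batch, 1, dim)
    combined = combined.reshape(-1, dim)                    # (batch, dim)

    # The bmm is just a per-token weighted sum of the top_k expert outputs.
    manual = (gate_score.unsqueeze(-1) * x).sum(dim=1)
    assert torch.allclose(combined, manual, atol=1e-6)

Because the reshape now uses tensor.shape[-1] instead of self.d_model, each leaf of the output structure may have its own feature dimension.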
requirements.txt (view file @ 74b6908f)

 torch
 numpy
 ninja
+dm-tree
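dm-tree is the package that provides the tree module (tree.map_structure, tree.flatten) now imported at the top of fmoe/layers.py, hence the new runtime dependency.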
tests/test_mimo.py (new file, 0 → 100644, view file @ 74b6908f)
import sys
import pytest
import torch
import torch.nn as nn
import numpy as np

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
from fmoe.megatron.layers import _megatron_init_method


def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().max()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > precision:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
            assert False


class MyExpert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand the input to 4h (the hidden size is variable, but is called h4
        for convenience). Then apply the activation. Finally shrink back to h.
        """
        if type(inp) == dict:
            x = inp["x"]
            y = inp["y"]
        elif type(inp) == list:
            x = inp[0]
            y = inp[1]
        else:
            raise NotImplementedError

        x = self.htoh4(x, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)

        y = self.htoh4(y, fwd_expert_count)
        y = self.activation(y)
        y = self.h4toh(y, fwd_expert_count)

        if type(inp) == dict:
            ret = {"x": x, "y": y}
        elif type(inp) == list:
            ret = [x, y]
        return ret


class MyGate(NaiveGate):
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def forward(self, inp, return_all_scores=False):
        if type(inp) == dict:
            x = inp["x"]
        elif type(inp) == list:
            x = inp[0]
        else:
            raise NotImplementedError
        return super().forward(x, return_all_scores)


class MyMoE(FMoE):
    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=MyGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = MyExpert(num_expert, d_model, d_hidden, activation)
        rng = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, rng, 1.0)
        _megatron_init_method(self.experts.h4toh, rng, 1.0)


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize(
    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
)
@pytest.mark.parametrize("list_input", [False, True])
def test_fmoe_mimo_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    list_input,
    activation=torch.nn.functional.gelu,
):
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=4 * d_model,
        world_size=world_size,
        mp_group=mp_group,
        top_k=top_k,
        activation=activation,
    ).cuda()

    x = torch.rand(batch_size, d_model).cuda()
    inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
    moe_out = moe(inp)

    if list_input:
        _assert_numerical(["x"], [moe_out[0]], [moe_out[1]], rank)
    else:
        _assert_numerical(["x"], [moe_out["x"]], [moe_out["y"]], rank)


if __name__ == "__main__":
    test_fmoe_mimo_linear(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.float32,
    )
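The test should run with a plain python -m pytest tests/test_mimo.py; note that it moves both the module and the inputs to CUDA, so a GPU (and a built fmoe extension) is presumably required.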