OpenDAS / FastMoE · Commits

Commit 28ba2d28 (unverified), authored May 23, 2021 by Colin, committed by GitHub on May 23, 2021

mask and experts list

parent baae8fb9

Showing 6 changed files with 152 additions and 8 deletions (+152 / -8)
.gitignore              +2   -0
fmoe/layers.py          +33  -6
fmoe/transformer.py     +4   -0
tests/moe.py            +5   -1
tests/test_numerical.py +102 -0
tests/test_zero.py      +6   -1
.gitignore (view file @ 28ba2d28)
@@ -11,3 +11,5 @@ build
*swp
logs
dist
**/.DS_Store
.idea
fmoe/layers.py (view file @ 28ba2d28)
@@ -132,6 +132,8 @@ class FMoE(nn.Module):
        gate=NaiveGate,
        expert=None,
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__()
        self.num_expert = num_expert
@@ -145,14 +147,20 @@ class FMoE(nn.Module):
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        if type(expert) is list:
            self.experts = nn.ModuleList([e(d_model) for e in expert])
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList([expert(d_model) for _ in range(num_expert)])
            self.experts_fused = False
        else:
            self.experts_fused = True
        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.gate_hook = gate_hook
        self.mask = mask
        self.mask_dict = mask_dict

    def expert_fn(self, inp, fwd_expert_count):
        r"""
@@ -196,14 +204,33 @@ class FMoE(nn.Module):
        gate_top_k_idx, gate_score = self.gate(inp)

        # delete masked tensors
        if self.mask != None and self.mask_dict != None:
            mask = self.mask.view(-1)
            # to: (BxL') x d_model
            inp = inp[mask == 0, :]
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            inp, gate_top_k_idx, self.expert_fn, self.num_expert, self.world_size
        )

        # recover deleted tensors
        if self.mask != None and self.mask_dict != None:
            # to: (BxL') x top_k x d_model
            fwd = fwd.view(-1, self.top_k, self.d_model)
            # to: (BxL) x top_k x d_model
            x = torch.zeros(mask.shape[0], self.top_k, self.d_model)
            # recover
            x[mask == 0] = fwd
            for k, v in self.mask_dict.items():
                x[mask == k] = v
        else:
            x = fwd.view(-1, self.top_k, self.d_model)

        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)
        if self.mp_size > 1:
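Taken together, the layers.py changes let FMoE accept a Python list of expert constructors (building one module per entry and overriding num_expert with the list length) and, via mask / mask_dict, drop masked tokens before dispatch and refill their outputs afterwards. Below is a minimal sketch of the heterogeneous-experts path only. It assumes the fmoe package from this repository is installed with its CUDA extension and a GPU is available; ToyExpert is a hypothetical expert module whose forward signature mirrors the experts in tests/moe.py.

import torch
import torch.nn as nn

from fmoe.layers import FMoE


class ToyExpert(nn.Module):
    """Hypothetical expert: each entry of the expert list is called as e(d_model)."""

    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x, fwd_expert_count=None):
        return self.linear(x)


d_model = 16
# Passing a list makes FMoE build nn.ModuleList([e(d_model) for e in expert]),
# set experts_fused=False, and replace num_expert with len(expert),
# so num_expert can be left as None (as the new test below does).
moe = FMoE(
    num_expert=None,
    d_model=d_model,
    expert=[ToyExpert, ToyExpert, ToyExpert, ToyExpert],
    world_size=1,
    mp_group=None,
    top_k=2,
).cuda()

out = moe(torch.rand(8, d_model).cuda())  # -> (8, d_model)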
fmoe/transformer.py (view file @ 28ba2d28)
@@ -49,6 +49,8 @@ class FMoETransformerMLP(FMoE):
        top_k=2,
        expert_dp_comm="none",
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__(
            num_expert=num_expert,
@@ -58,6 +60,8 @@ class FMoETransformerMLP(FMoE):
            world_size=world_size,
            mp_group=mp_group,
            gate_hook=gate_hook,
            mask=mask,
            mask_dict=mask_dict,
        )
        self.experts = _Expert(
            num_expert, d_model, d_hidden, activation, rank=self.mp_rank
tests/moe.py (view file @ 28ba2d28)
@@ -55,7 +55,11 @@ class BruteForceMoE(nn.Module):
        self.num_expert = num_expert
        self.d_model = d_model
        self.top_k = top_k
        if type(expert) is list:
            self.experts = [e(d_model) for e in expert]
            self.num_expert = num_expert = len(expert)
        else:
            self.experts = [expert(d_model) for _ in range(num_expert * world_size)]

    def forward(self, inp, gate_idx, gate_score):
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
tests/test_numerical.py (view file @ 28ba2d28)
@@ -384,6 +384,107 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    _assert_numerical(names, ddp_out_list, raw_out_list, rank)


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize(
    "expert",
    [
        [NaiveExpert for _ in range(4)],
        [
            LinearExpert,
            NaiveExpert,
            LinearExpert,
            NaiveExpert,
            LinearExpert,
            NaiveExpert,
            LinearExpert,
            NaiveExpert,
        ],
    ],
)
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_experts(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
):
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)
    if isinstance(expert, str):
        expert = globals()[expert]
    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()
    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[idx].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]
    _assert_numerical(names, moe_out_list, raw_out_list, rank)


if __name__ == "__main__":
    test_fmoe_linear(
        batch_size=2,
@@ -396,4 +497,5 @@ if __name__ == "__main__":
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.float32,
    )
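The __main__ block above only drives test_fmoe_linear; the new parametrized test can also be called directly in the same style. A hedged sketch (not part of the commit), assuming a single CUDA GPU and that the tests/ directory is on sys.path so moe.py and test_numerical.py import as plain modules:

# Direct invocation of the new test, bypassing pytest's parametrization.
from moe import NaiveExpert, LinearExpert  # tests/moe.py
from test_numerical import test_fmoe_experts  # tests/test_numerical.py

test_fmoe_experts(
    batch_size=4,
    num_expert=None,  # ignored when a list of experts is given
    d_model=16,
    top_k=2,
    expert=[NaiveExpert, LinearExpert, NaiveExpert, LinearExpert],
    rank=0,
    world_size=1,
    mp_group=None,
    dp_group=None,
    world_group=None,
)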
tests/test_zero.py (view file @ 28ba2d28)
@@ -51,8 +51,13 @@ def test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
def _test_zero_transformer(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
    inp = torch.rand(batch_size, d_hidden).cuda()
    mask = torch.zeros(inp.shape[0], dtype=torch.long)
    mask[1] = 1
    mask_dict = {1: torch.zeros(d_hidden).cuda()}
    model = FMoETransformerMLP(
        num_expert, d_hidden, d_hidden * 4, world_size,
        gate=ConstantGate, mask=mask, mask_dict=mask_dict
    ).cuda()
    oup = model(inp)
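The mask semantics this test exercises follow the delete/recover logic added to FMoE.forward above: tokens whose mask value is 0 are routed through the experts, while tokens whose mask value is k > 0 are dropped before dispatch and their output rows are filled from mask_dict[k] afterwards. A standalone sketch of just that bookkeeping, with made-up shapes and a stand-in for the expert outputs, so it has no fmoe or GPU dependency:

import torch

d_model, top_k = 8, 2
mask = torch.tensor([0, 1, 0, 0])      # token 1 is masked with key 1
mask_dict = {1: torch.zeros(d_model)}  # masked tokens get this vector

inp = torch.rand(mask.shape[0], d_model)

# delete: only unmasked tokens would be sent to the experts
kept = inp[mask == 0, :]                        # (BxL') x d_model
fwd = kept.unsqueeze(1).repeat(1, top_k, 1)     # stand-in for expert outputs

# recover: scatter expert outputs back and fill masked rows from mask_dict
x = torch.zeros(mask.shape[0], top_k, d_model)  # (BxL) x top_k x d_model
x[mask == 0] = fwd
for k, v in mask_dict.items():
    x[mask == k] = v

assert torch.equal(x[1], torch.zeros(top_k, d_model))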