OpenDAS / FastMoE

Commit f39f411a, authored Feb 22, 2021 by Sengxian
Parent: 27c89b5a

    customized top_k and add test for localddp

Showing 3 changed files with 167 additions and 24 deletions:

    fmoe/megatron.py          +7    -1
    tests/test_ddp.py         +28   -10
    tests/test_numerical.py   +132  -13
fmoe/megatron.py

@@ -24,6 +24,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         super().__init__(args.num_experts,
+            top_k=args.top_k,
             d_model=args.hidden_size, d_hidden=args.hidden_hidden_size,
             world_size=world_size, mp_group=group)
         self.bias = torch.nn.parameter.Parameter(

@@ -35,7 +36,7 @@ class MegatronMLP(FMoETransformerMLP):
 def fmoefy(model, num_experts=None, distributed_experts=True,
-        hidden_hidden_size=None):
+        hidden_hidden_size=None, top_k=None):
     r'''
     Replace MLP layers in a transformer-based model in Megatron by MoE.

     * `model` should be a standard Megatron model that has

@@ -63,6 +64,11 @@ def fmoefy(model, num_experts=None, distributed_experts=True,
     elif not hasattr(args, 'hidden_hidden_size'):
         args.hidden_hidden_size = args.hidden_size * 4

+    if top_k is not None:
+        args.top_k = top_k
+    elif not hasattr(args, 'top_k'):
+        args.top_k = 2
+
     # Set distributed_experts to None to use default setting in args
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
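The new `top_k` argument follows the same pattern as `hidden_hidden_size`: an explicit value overrides `args.top_k`, and a default of 2 is used when Megatron's args carry no such field. A minimal usage sketch, not part of this commit, assuming Megatron's global args are already initialized and that `fmoefy` returns the patched model:

    from fmoe.megatron import fmoefy

    # Route every token to its top-4 experts instead of the default top-2.
    # `model` is a hypothetical, already-built Megatron transformer.
    model = fmoefy(model, num_experts=8, top_k=4)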
tests/test_ddp.py

@@ -8,6 +8,7 @@ import torch
 from test_numerical import test_fmoe as _test_fmoe
 from test_numerical import test_fmoe_linear as _test_fmoe_linear
+from test_numerical import _test_fmoe_local_ddp


 def _run_distributed(func, world_size, args: Dict):

@@ -78,6 +79,13 @@ def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_siz
     )


+@pytest.mark.parametrize("mp_size", [1, 2])
+def test_fmoe_local_ddp(mp_size):
+    _run_distributed(
+        _test_fmoe_local_ddp.__name__, mp_size * 2, {"mp_size": mp_size},
+    )
+
+
 if __name__ == "__main__":
     if len(sys.argv) >= 3:
         args = json.loads(sys.argv[2])

@@ -87,20 +95,30 @@ if __name__ == "__main__":
         torch.distributed.init_process_group(backend="nccl")
         args["rank"] = torch.distributed.get_rank()
         args["world_size"] = torch.distributed.get_world_size()
-        args["mp_group"] = (
-            [
-                torch.distributed.new_group(
-                    ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
-                    backend="nccl",
-                )
-                for j in range(args["world_size"] // args["mp_size"])
-            ][args["rank"] // args["mp_size"]]
-            if args["mp_size"] > 1
-            else None
-        )
+        args["mp_group"] = [
+            torch.distributed.new_group(
+                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
+                backend="nccl",
+            )
+            for j in range(args["world_size"] // args["mp_size"])
+        ][args["rank"] // args["mp_size"]]
+        args["dp_group"] = [
+            torch.distributed.new_group(
+                ranks=[
+                    i * args["mp_size"] + j
+                    for i in range(args["world_size"] // args["mp_size"])
+                ],
+                backend="nccl",
+            )
+            for j in range(args["mp_size"])
+        ][args["rank"] % args["mp_size"]]
+        args["world_group"] = torch.distributed.new_group(
+            ranks=list(range(args["world_size"])),
+            backend="nccl",
+        )
         del args["mp_size"]
         locals()[sys.argv[1]](**args)
     else:
+        test_fmoe_local_ddp(mp_size=2)
         test_fmoe_linear_distributed(
             num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
         )
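In the test launcher, each rank now joins three process groups: a model-parallel group of the `mp_size` consecutive ranks it shares experts with, a data-parallel group of the ranks that hold the same model slice, and a world group spanning every rank. The comprehensions above build the full list of groups and index into it with `rank // mp_size` and `rank % mp_size`; building all groups on every rank is required because `torch.distributed.new_group()` must be entered by all processes. An illustration of the resulting rank layout, not part of the commit, computed with plain Python for world_size = 4 and mp_size = 2:

    # Mirror the rank arithmetic used in test_ddp.py's __main__ block.
    world_size, mp_size = 4, 2
    mp_groups = [[j * mp_size + i for i in range(mp_size)]
                 for j in range(world_size // mp_size)]    # [[0, 1], [2, 3]]
    dp_groups = [[i * mp_size + j for i in range(world_size // mp_size)]
                 for j in range(mp_size)]                   # [[0, 2], [1, 3]]
    for rank in range(world_size):
        print(rank, mp_groups[rank // mp_size], dp_groups[rank % mp_size])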
tests/test_numerical.py

 import sys
+from collections import OrderedDict
 from typing import List, Type, Union

 import pytest
 import torch
 import torch.nn as nn
+from copy import deepcopy

 from fmoe.gates import NaiveGate
 from fmoe.layers import FMoE
 from fmoe.transformer import _Expert
+from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
@@ -53,15 +56,16 @@ def _assert_numercial(names, moe_out_list, raw_out_list, rank):
 class MyMoE(FMoE):
-    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
-            top_k, activation):
+    def __init__(
+        self, num_expert, d_model, d_hidden, world_size, mp_group,
+        top_k, activation):
         super().__init__(
             num_expert=num_expert,
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
             mp_group=mp_group,
-            top_k=top_k
+            top_k=top_k,
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)

@@ -74,6 +78,8 @@ class MyMoE(FMoE):
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
 @pytest.mark.parametrize("mp_group", [None])
+@pytest.mark.parametrize("dp_group", [None])
+@pytest.mark.parametrize("world_group", [None])
 def test_fmoe_linear(
     num_expert,
     top_k,

@@ -83,13 +89,16 @@ def test_fmoe_linear(
     rank,
     world_size,
     mp_group,
+    dp_group,
+    world_group,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
-    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
-            activation).cuda()
+    moe = MyMoE(
+        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
+    ).cuda()
     moe_raw = BruteForceMoELinear(
         activation=activation,

@@ -132,8 +141,20 @@ def test_fmoe_linear(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
     )
-    moe_out_list = moe_out, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
-    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad
+    moe_out_list = (
+        moe_out,
+        moe.experts.htoh4.weight.grad,
+        moe.experts.h4toh.weight.grad,
+        moe.experts.htoh4.bias.grad,
+        moe.experts.h4toh.bias.grad,
+    )
+    raw_out_list = (
+        raw_out,
+        moe_raw.weight_htoh4.grad,
+        moe_raw.weight_h4toh.grad,
+        moe_raw.bias_htoh4.grad,
+        moe_raw.bias_h4toh.grad,
+    )

     if world_size > 1:
         _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list

@@ -142,13 +163,27 @@ def test_fmoe_linear(
         torch.distributed.all_reduce(htoh4_b_grad)
         torch.distributed.all_reduce(h4toh_b_grad)
         mp_size = mp_group.size() if mp_group else 1
-        htoh4_w_grad = htoh4_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
-        h4toh_w_grad = h4toh_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
-        htoh4_b_grad = htoh4_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
-        h4toh_b_grad = h4toh_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        htoh4_w_grad = (
+            htoh4_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        )
+        h4toh_w_grad = (
+            h4toh_w_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        )
+        htoh4_b_grad = (
+            htoh4_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        )
+        h4toh_b_grad = (
+            h4toh_b_grad[rank * num_expert: (rank + 1) * num_expert] / mp_size
+        )
         raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

-    names = ["output", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
+    names = [
+        "output",
+        "htoh4 weight grad",
+        "h4toh weight grad",
+        "htoh4 bias grad",
+        "h4toh bias grad",
+    ]
     _assert_numercial(names, moe_out_list, raw_out_list, rank)
@@ -160,6 +195,8 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("rank", [0])
 @pytest.mark.parametrize("world_size", [1])
 @pytest.mark.parametrize("mp_group", [None])
+@pytest.mark.parametrize("dp_group", [None])
+@pytest.mark.parametrize("world_group", [None])
 def test_fmoe(
     batch_size,
     num_expert,

@@ -167,8 +204,10 @@ def test_fmoe(
     top_k,
     expert: Union[Type[nn.Module], str],
     rank,
-    mp_group,
     world_size,
+    mp_group,
+    dp_group,
+    world_group,
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)

@@ -249,6 +288,82 @@ def test_fmoe(
     _assert_numercial(names, moe_out_list, raw_out_list, rank)


+class MyModule(nn.Module):
+    def __init__(self, dim=8):
+        super(MyModule, self).__init__()
+        self.model = nn.Sequential(
+            OrderedDict(
+                [
+                    ("linear1", nn.Linear(dim, dim)),
+                    ("relu1", nn.ReLU()),
+                    ("linear2", nn.Linear(dim, dim)),
+                    ("relu2", nn.ReLU()),
+                    ("linear3", nn.Linear(dim, dim)),
+                ]
+            )
+        )
+
+    def set_comm(self):
+        for p in self.model._modules["linear1"].parameters():
+            setattr(p, "dp_comm", "mp")
+        for p in self.model._modules["linear2"].parameters():
+            setattr(p, "dp_comm", "dp")
+        for p in self.model._modules["linear3"].parameters():
+            setattr(p, "dp_comm", "world")
+
+    def forward(self, inp):
+        return self.model(inp)
+
+
+def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
+    batch_size, dim = 4, 8
+    torch.manual_seed(42 + rank)
+    torch.cuda.manual_seed(42 + rank)
+
+    model = MyModule().cuda()
+    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
+    model.set_comm()
+    model_ddp.module.set_comm()
+
+    inp = torch.randn(batch_size, dim).cuda()
+    raw_out = model(inp).mean()
+    ddp_out = model_ddp(inp).mean()
+    raw_out.backward()
+    ddp_out.backward()
+
+    torch.distributed.all_reduce(
+        model.model._modules["linear1"].weight.grad.data, group=mp_group
+    )
+    model.model._modules["linear1"].weight.grad /= mp_group.size()
+    torch.distributed.all_reduce(
+        model.model._modules["linear2"].weight.grad.data, group=dp_group
+    )
+    model.model._modules["linear2"].weight.grad /= dp_group.size()
+    torch.distributed.all_reduce(
+        model.model._modules["linear3"].weight.grad.data, group=world_group
+    )
+    model.model._modules["linear3"].weight.grad /= world_group.size()
+
+    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)
+
+    raw_out_list = [
+        model.model._modules["linear1"].weight.grad,
+        model.model._modules["linear2"].weight.grad,
+        model.model._modules["linear3"].weight.grad,
+    ]
+    ddp_out_list = [
+        model_ddp.module.model._modules["linear1"].weight.grad,
+        model_ddp.module.model._modules["linear2"].weight.grad,
+        model_ddp.module.model._modules["linear3"].weight.grad,
+    ]
+    names = ["mp grad", "dp grad", "wp grad"]
+    _assert_numercial(names, ddp_out_list, raw_out_list, rank)
+
+
 if __name__ == "__main__":
     test_fmoe_linear(
         batch_size=4,

@@ -259,6 +374,8 @@ if __name__ == "__main__":
         rank=0,
         world_size=1,
         mp_group=None,
+        dp_group=None,
+        world_group=None,
     )
     test_fmoe(
         batch_size=4,

@@ -269,4 +386,6 @@ if __name__ == "__main__":
         rank=0,
         world_size=1,
         mp_group=None,
+        dp_group=None,
+        world_group=None,
     )
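The new `_test_fmoe_local_ddp` check works by tagging each parameter with a `dp_comm` attribute ("mp", "dp", or "world") and letting `DistributedGroupedDataParallel.allreduce_params` reduce its gradient over the matching process group, then comparing the result against the same reduction done by hand on a plain copy of the model. A condensed usage sketch, not part of the commit, assuming the `MyModule` class and the `mp_group`/`dp_group`/`world_group` handles defined in the test above:

    # Wrap the model, tag parameters with the group each gradient belongs to,
    # and let allreduce_params() average gradients according to the tag.
    net = LocalDDP(MyModule().cuda(), mp_group, dp_group, world_group)
    for p in net.module.model._modules["linear1"].parameters():
        setattr(p, "dp_comm", "mp")   # reduce over the model-parallel group
    net(torch.randn(4, 8).cuda()).mean().backward()
    net.allreduce_params(reduce_after=False, fp32_allreduce=False)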