OpenDAS / FastMoE · Commits

Commit 06e75b3a, authored Feb 08, 2021 by Sengxian
Commit message: Seperate ddp test from test_numerical
Parent: 1a72a0cb
Showing 2 changed files with 117 additions and 81 deletions (+117 / -81):

    tests/test_ddp.py        +84 / -0
    tests/test_numerical.py  +33 / -81
tests/test_ddp.py (new file, mode 100644)
```python
import json
import os
import sys
from typing import Dict

import pytest
import torch

from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear


def _run_distributed(func, args: Dict):
    import subprocess
    import os

    ps, n = [], 2
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
    for i in range(n):
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
        p = subprocess.Popen(
            [sys.executable, __file__, func, json.dumps(args)],
            stdout=subprocess.PIPE,
        )
        ps.append(p)
    for p in ps:
        p.wait()
        retc = p.poll()
        assert retc == 0


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear_distributed(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
):
    _run_distributed(
        "_test_fmoe_linear",
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "d_hidden": d_hidden,
        },
    )


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe_distributed(
    num_expert,
    top_k,
    batch_size,
    d_model,
    expert,
):
    _run_distributed(
        "_test_fmoe",
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "expert": expert,
        },
    )


if __name__ == "__main__":
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        torch.distributed.init_process_group(backend="nccl")
        args["rank"] = torch.distributed.get_rank()
        args["world_size"] = torch.distributed.get_world_size()
        locals()[sys.argv[1]](**args)
```
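The new file acts as both the pytest entry point and its own worker: `_run_distributed` re-invokes the file once per rank with the test name and a JSON argument blob, and the `__main__` block translates the OMPI-style environment variables into `RANK`/`WORLD_SIZE`, initializes NCCL, and dispatches to the numerical test. As a usage note (not part of the commit), a smoke run of just these distributed tests could be driven through pytest's Python API; this sketch assumes at least two CUDA-visible devices and an installed FastMoE build:

```python
# Hypothetical driver, not from the commit: run only the distributed tests
# defined in tests/test_ddp.py via pytest's public API. Two worker processes
# are spawned per test by _run_distributed above, so two GPUs are assumed.
import sys
import pytest

sys.exit(pytest.main(["-x", "tests/test_ddp.py", "-k", "distributed"]))
```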
tests/test_numerical.py
```diff
@@ -12,9 +12,6 @@ from fmoe.layers import FMoE
 from fmoe.transformer import _Expert
 from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
 
-rank = 0
-world_size = 1
-
 
 def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
     moe.zero_grad()
@@ -31,7 +28,7 @@ def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, to
     return moe_out, raw_out
 
 
-def _assert_numercial(names, moe_out_list, raw_out_list):
+def _assert_numercial(names, moe_out_list, raw_out_list, rank):
     for name, mo, ro in zip(names, moe_out_list, raw_out_list):
         err = (mo - ro).abs().sum()
         print("Rank {} {} abs err {}".format(rank, name, err))
@@ -48,12 +45,16 @@ def _assert_numercial(names, moe_out_list, raw_out_list):
 @pytest.mark.parametrize("batch_size", [4])
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("d_hidden", [32])
+@pytest.mark.parametrize("rank", [0])
+@pytest.mark.parametrize("world_size", [1])
 def test_fmoe_linear(
     num_expert,
     top_k,
     batch_size,
     d_model,
     d_hidden,
+    rank,
+    world_size,
     activation=torch.nn.functional.gelu,
 ):
     torch.manual_seed(42 + rank)
@@ -113,7 +114,7 @@ def test_fmoe_linear(
     raw_out_list = _, htoh4_grad, h4toh_grad
     names = ["output", "htoh4 weight grad", "h4toh weight grad"]
-    _assert_numercial(names, moe_out_list, raw_out_list)
+    _assert_numercial(names, moe_out_list, raw_out_list, rank)
 
 
 @pytest.mark.parametrize("batch_size", [4])
@@ -121,8 +122,16 @@ def test_fmoe_linear(
 @pytest.mark.parametrize("d_model", [16])
 @pytest.mark.parametrize("top_k", [2, 3])
 @pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
+@pytest.mark.parametrize("rank", [0])
+@pytest.mark.parametrize("world_size", [1])
 def test_fmoe(
-    batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str]
+    batch_size,
+    num_expert,
+    d_model,
+    top_k,
+    expert: Union[Type[nn.Module], str],
+    rank,
+    world_size,
 ):
     torch.manual_seed(42 + rank)
     torch.cuda.manual_seed(42 + rank)
@@ -197,82 +206,25 @@ def test_fmoe(
     raw_out_list = [raw_out, raw_grad]
     names = ["forward", "backward"]
-    _assert_numercial(names, moe_out_list, raw_out_list)
+    _assert_numercial(names, moe_out_list, raw_out_list, rank)
 
 
-def _run_distributed(func: Callable, args: Dict):
-    import subprocess
-    import os
-
-    ps, n = [], 2
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "36666"
-    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
-    for i in range(n):
-        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
-        p = subprocess.Popen(
-            [sys.executable, __file__, func.__name__, json.dumps(args)],
-            stdout=subprocess.PIPE,
-        )
-        ps.append(p)
-
-    for p in ps:
-        p.wait()
-        retc = p.poll()
-        assert retc == 0
-
-
-@pytest.mark.parametrize("num_expert", [4, 8])
-@pytest.mark.parametrize("top_k", [2])
-@pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("d_model", [16])
-@pytest.mark.parametrize("d_hidden", [32])
-def test_fmoe_linear_distributed(
-    num_expert,
-    top_k,
-    batch_size,
-    d_model,
-    d_hidden,
-):
-    _run_distributed(
-        test_fmoe_linear,
-        {
-            "num_expert": num_expert,
-            "top_k": top_k,
-            "batch_size": batch_size,
-            "d_model": d_model,
-            "d_hidden": d_hidden,
-        },
-    )
-
-
-@pytest.mark.parametrize("num_expert", [4, 8])
-@pytest.mark.parametrize("top_k", [2])
-@pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("d_model", [16])
-@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
-def test_fmoe_distributed(
-    num_expert,
-    top_k,
-    batch_size,
-    d_model,
-    expert,
-):
-    _run_distributed(
-        test_fmoe,
-        {
-            "num_expert": num_expert,
-            "top_k": top_k,
-            "batch_size": batch_size,
-            "d_model": d_model,
-            "expert": expert,
-        },
-    )
-
-
 if __name__ == "__main__":
-    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
-    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
-    if int(os.environ["WORLD_SIZE"]) > 1:
-        torch.distributed.init_process_group(backend="nccl")
-        rank = torch.distributed.get_rank()
-        world_size = torch.distributed.get_world_size()
-    if len(sys.argv) >= 3:
-        locals()[sys.argv[1]](**json.loads(sys.argv[2]))
-    else:
-        test_fmoe_linear(batch_size=4, num_expert=4, d_model=8, top_k=2, d_hidden=16)
-        test_fmoe(batch_size=4, num_expert=4, d_model=8, top_k=2, expert=NaiveExpert)
+    test_fmoe_linear(
+        batch_size=4,
+        num_expert=4,
+        d_model=8,
+        top_k=2,
+        d_hidden=16,
+        rank=0,
+        world_size=1,
+    )
+    test_fmoe(
+        batch_size=4,
+        num_expert=4,
+        d_model=8,
+        top_k=2,
+        expert=NaiveExpert,
+        rank=0,
+        world_size=1,
+    )
```
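After this change, `rank` and `world_size` are explicit parameters of the numerical tests instead of module-level globals, so a single-process run looks like the new `__main__` block above. A minimal sketch, assuming one CUDA device and that the tests/ directory is on sys.path (as pytest arranges when collecting these files):

```python
# Single-process smoke run mirroring the new __main__ block of test_numerical.py.
from test_numerical import test_fmoe, test_fmoe_linear
from moe import NaiveExpert

# rank/world_size are now passed explicitly rather than read from globals.
test_fmoe_linear(batch_size=4, num_expert=4, d_model=8, top_k=2, d_hidden=16,
                 rank=0, world_size=1)
test_fmoe(batch_size=4, num_expert=4, d_model=8, top_k=2, expert=NaiveExpert,
          rank=0, world_size=1)
```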