OpenDAS / FastMoE · Commits

Commit 49b5b5d6
authored Mar 28, 2022 by Rick Ho
parent 6e1fcca1

    test for faster gate

Showing 2 changed files with 76 additions and 9 deletions:

    tests/test_ddp.py          +14 -9
    tests/test_faster_gate.py  +62 -0
tests/test_ddp.py  +14 -9  (view file @ 49b5b5d6; whitespace changes hidden)

 import json
+import random
 import os
 import sys
 from typing import Dict
 ...

@@ -13,30 +14,34 @@ from test_numerical import _test_fmoe_local_ddp

 def _ensure_initialized():
-    if not dist.is_initialized():
     if 'RANK' not in os.environ:
         os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
         os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
         os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
         os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
         os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
+    if not dist.is_initialized():
         dist.init_process_group(backend="nccl")
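The reshuffled `_ensure_initialized` now fills in torch.distributed's rendezvous variables before checking `dist.is_initialized()`, so a worker launched either by mpirun (via the `OMPI_COMM_WORLD_*` variables) or by the subprocess runner below ends up with a working NCCL process group. A minimal sketch of the single-process fallback, assuming a CUDA machine with NCCL available (the snippet is illustrative, not part of the commit):

# Minimal sketch: with no launcher at all, _ensure_initialized() falls back
# to RANK=0 / WORLD_SIZE=1 and a localhost rendezvous, so a _test_* entry
# point can be invoked directly in a single process.
import torch.distributed as dist
from test_ddp import _ensure_initialized

_ensure_initialized()                 # derives RANK/WORLD_SIZE/MASTER_* defaults
assert dist.is_initialized()
assert dist.get_rank() == 0 and dist.get_world_size() == 1

The second half of the hunk reworks the launcher itself: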
-def _run_distributed(func, world_size, args: Dict, script=__file__):
-    if torch.cuda.device_count() < world_size:
-        pytest.skip("No enough GPU")
+def _run_distributed(func, world_size, args: Dict, script=__file__, env=dict()):
+    device_count = torch.cuda.device_count()
+    if device_count < world_size:
+        pytest.skip("No enough GPU, only {} found".format(device_count))
     import subprocess
     import os
     ps = []
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "36666"
-    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
+    env["MASTER_ADDR"] = "localhost"
+    env["MASTER_PORT"] = str(random.randint(50000, 60000))
+    env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
     for i in range(world_size):
-        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
+        env["OMPI_COMM_WORLD_RANK"] = str(i)
         p = subprocess.Popen([sys.executable, script, func, json.dumps(args)],
-                stdout=subprocess.PIPE)
+                stdout=subprocess.PIPE, env=env)
         ps.append(p)
 ...
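`_run_distributed` previously mutated the parent's `os.environ`; it now builds a per-launch `env` dict that is handed to `subprocess.Popen`, which uses it as the child's entire environment. The randomized `MASTER_PORT` also avoids rendezvous port collisions between consecutive test runs. A hedged sketch of a caller using the new parameter (the `_test_something` entry point and the variable names are illustrative, not from the commit):

# Illustrative caller of the reworked launcher. The env dict is extended
# with MASTER_ADDR/MASTER_PORT and the OMPI_COMM_WORLD_* variables inside
# _run_distributed, then passed verbatim as each child's environment.
from test_ddp import _run_distributed

def test_something():
    _run_distributed(
        '_test_something',            # worker entry point defined in `script`
        2,                            # world size: processes to spawn
        {'d_model': 256},             # worker kwargs, JSON-serialized onto argv[2]
        script=__file__,
        env=dict(SOME_FMOE_FLAG="1"), # extra per-test environment variables
    )

Note that passing `env=` to Popen replaces rather than merges the child's environment, so anything a worker needs must be placed in the dict, and since `env=dict()` is a mutable default argument, callers that omit it share (and mutate) a single dict across calls.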
tests/test_faster_gate.py  +62 -0  (new file, 0 → 100644; view file @ 49b5b5d6)

import pytest
import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from fmoe.gates.faster_gate import FasterGate
from test_ddp import _ensure_initialized, _run_distributed


@pytest.mark.parametrize("n_process", [8])
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("gpu_per_node", [2, 4, 8])
@pytest.mark.parametrize("frac", [.2])
def test_faster_gate(n_process, d_model, batch_size, n_expert, gpu_per_node, frac):
    _run_distributed('_test_faster_gate',
            n_process,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'gpu_per_node': gpu_per_node,
                'frac': frac
            },
            script=__file__,
            env=dict(
                FMOE_TOPO_GPUS_PER_NODE=str(gpu_per_node),
                FMOE_TOPO_OUTGOING_FRACTION=str(frac)
            ))


def _test_faster_gate(d_model, batch_size, n_expert, gpu_per_node, frac):
    _ensure_initialized()
    rank = dist.get_rank()
    node_rank = rank // gpu_per_node
    gate = FasterGate(d_model, n_expert, dist.get_world_size(), node_rank).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    cnto = 0
    idxs = topk_idx[:, 0].cpu().view(-1).numpy()
    for v in idxs:
        assert(v != -1)
        if v // n_expert // gpu_per_node != rank // gpu_per_node:
            cnto += 1
    assert(cnto <= math.ceil(batch_size * frac))


if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        test_faster_gate(8, 1024, 16, 1, 2, .2)
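In `_test_faster_gate`, the locality check relies on FastMoE's consecutive expert numbering: with `n_expert` experts per worker and `gpu_per_node` workers per node, global expert index `v` belongs to worker `v // n_expert` and therefore to node `v // n_expert // gpu_per_node`. A self-contained sketch of that arithmetic with illustrative values (not part of the commit):

# Expert-index -> node bookkeeping behind the test's assertion.
n_expert, gpu_per_node, world_size = 4, 2, 8   # illustrative values

def node_of_expert(v):
    # worker v // n_expert owns expert v; integer-divide again for its node
    return v // n_expert // gpu_per_node

assert node_of_expert(0) == 0    # worker 0, node 0
assert node_of_expert(7) == 0    # worker 1 (experts 4..7), still node 0
assert node_of_expert(8) == 1    # worker 2, node 1
assert node_of_expert(world_size * n_expert - 1) == world_size // gpu_per_node - 1

The test counts the top-1 gate choices that leave the caller's node and asserts the count stays within `ceil(batch_size * frac)`, the bound configured through FMOE_TOPO_OUTGOING_FRACTION.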