Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
26cc37cb
Unverified
Commit
26cc37cb
authored
May 24, 2021
by
GODVIX
Committed by
GitHub
May 24, 2021
Browse files
Update test_gates.py
parent
ddaac5eb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
1 deletion
+28
-1
tests/test_gates.py
tests/test_gates.py
+28
-1
No files found.
tests/test_gates.py
View file @
26cc37cb
...
@@ -7,6 +7,7 @@ import math
...
@@ -7,6 +7,7 @@ import math
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
fmoe.gates
import
GShardGate
,
SwitchGate
from
fmoe.gates
import
GShardGate
,
SwitchGate
from
test_ddp
import
_run_distributed
from
test_ddp
import
_run_distributed
...
@@ -54,6 +55,13 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
...
@@ -54,6 +55,13 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
for
i
in
counts
:
for
i
in
counts
:
assert
(
i
<=
real_cap
)
assert
(
i
<=
real_cap
)
gate_score
=
gate
.
gate
(
x
)
for
i
in
range
(
batch_size
):
for
j
in
range
(
gate
.
top_k
):
v
=
topk_idx
[
i
,
j
]
if
v
!=
-
1
:
assert
topk_val
[
i
,
j
]
==
gate_score
[
i
,
v
]
@
pytest
.
mark
.
parametrize
(
"d_model"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"d_model"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4096
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4096
])
...
@@ -77,6 +85,7 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
...
@@ -77,6 +85,7 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
gate
=
SwitchGate
(
d_model
,
n_expert
,
dist
.
get_world_size
(),
gate
=
SwitchGate
(
d_model
,
n_expert
,
dist
.
get_world_size
(),
capacity
=
(
cap
,
cap
)).
cuda
()
capacity
=
(
cap
,
cap
)).
cuda
()
x
=
torch
.
rand
(
batch_size
,
d_model
).
cuda
()
x
=
torch
.
rand
(
batch_size
,
d_model
).
cuda
()
rng
=
torch
.
cuda
.
get_rng_state
()
# save rng state
topk_idx
,
topk_val
=
gate
(
x
)
topk_idx
,
topk_val
=
gate
(
x
)
counts
=
[
0
for
_
in
range
(
n_expert
*
dist
.
get_world_size
())]
counts
=
[
0
for
_
in
range
(
n_expert
*
dist
.
get_world_size
())]
for
v
in
topk_idx
.
cpu
().
view
(
-
1
).
numpy
():
for
v
in
topk_idx
.
cpu
().
view
(
-
1
).
numpy
():
...
@@ -86,6 +95,24 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
...
@@ -86,6 +95,24 @@ def _test_switch_gate(d_model, batch_size, n_expert, cap):
for
i
in
counts
:
for
i
in
counts
:
assert
(
i
<=
real_cap
)
assert
(
i
<=
real_cap
)
score
=
gate
.
gate
(
x
)
if
gate
.
training
:
# reset rng state to make sure noise is the same as in gate.forward()
torch
.
cuda
.
set_rng_state
(
rng
)
# random uniform number from [1-eps, 1+eps]
noise
=
torch
.
rand_like
(
score
)
noise
=
noise
*
2
*
gate
.
switch_eps
+
1.0
-
gate
.
switch_eps
score
+=
noise
# fp32 softmax for numerical stability
score
=
F
.
softmax
(
score
.
float
(),
dim
=-
1
)
for
i
in
range
(
batch_size
):
v
=
topk_idx
[
i
]
if
v
!=
-
1
:
assert
topk_val
[
i
]
==
score
[
i
,
topk_idx
[
i
]]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
>=
3
:
if
len
(
sys
.
argv
)
>=
3
:
...
@@ -94,5 +121,5 @@ if __name__ == '__main__':
...
@@ -94,5 +121,5 @@ if __name__ == '__main__':
else
:
else
:
_ensure_initialized
()
_ensure_initialized
()
# test_gshard_gate(4096, 1024, 4, .2)
# test_gshard_gate(4096, 1024, 4, .2)
test_
gshard
_gate
(
8
,
16
,
1
,
.
1
)
test_
switch
_gate
(
8
,
16
,
4
,
.
1
)
# test_switch_gate(4096, 1024, 4, .2)
# test_switch_gate(4096, 1024, 4, .2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment