Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
e01c8b5f
"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "a339f573f09be219c4ac368b441e3580cd89cf57"
Commit
e01c8b5f
authored
May 27, 2022
by
Rick Ho
Browse files
fix gshard gate test
parent
d3f71f21
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
3 deletions
+12
-3
tests/test_gates.py
tests/test_gates.py
+12
-3
No files found.
tests/test_gates.py
View file @
e01c8b5f
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
fmoe.gates
import
GShardGate
,
SwitchGate
from
fmoe.gates
import
GShardGate
,
SwitchGate
from
fmoe.functions
import
ensure_comm
from
test_ddp
import
_ensure_initialized
,
_run_distributed
from
test_ddp
import
_ensure_initialized
,
_run_distributed
...
@@ -33,9 +34,11 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
...
@@ -33,9 +34,11 @@ def test_gshard_gate(d_model, batch_size, n_expert, cap):
def
_test_gshard_gate
(
d_model
,
batch_size
,
n_expert
,
cap
):
def
_test_gshard_gate
(
d_model
,
batch_size
,
n_expert
,
cap
):
_ensure_initialized
()
_ensure_initialized
()
rank
=
torch
.
distributed
.
get_rank
()
gate
=
GShardGate
(
d_model
,
n_expert
,
dist
.
get_world_size
(),
gate
=
GShardGate
(
d_model
,
n_expert
,
dist
.
get_world_size
(),
capacity
=
(
cap
,
cap
)).
cuda
()
capacity
=
(
cap
,
cap
)).
cuda
()
x
=
torch
.
rand
(
batch_size
,
d_model
).
cuda
()
x
=
torch
.
rand
(
batch_size
,
d_model
).
cuda
()
ensure_comm
(
x
,
None
)
topk_idx
,
topk_val
=
gate
(
x
)
topk_idx
,
topk_val
=
gate
(
x
)
counts
=
[
0
for
_
in
range
(
n_expert
*
dist
.
get_world_size
())]
counts
=
[
0
for
_
in
range
(
n_expert
*
dist
.
get_world_size
())]
for
v
in
topk_idx
.
cpu
().
view
(
-
1
).
numpy
():
for
v
in
topk_idx
.
cpu
().
view
(
-
1
).
numpy
():
...
@@ -46,11 +49,17 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
...
@@ -46,11 +49,17 @@ def _test_gshard_gate(d_model, batch_size, n_expert, cap):
assert
(
i
<=
real_cap
)
assert
(
i
<=
real_cap
)
gate_score
=
gate
.
gate
(
x
)
gate_score
=
gate
.
gate
(
x
)
gate_top_k_val
,
gate_top_k_idx
=
torch
.
topk
(
gate_score
,
k
=
gate
.
top_k
,
dim
=-
1
,
largest
=
True
,
sorted
=
False
)
gate_top_k_val
=
gate_top_k_val
.
view
(
-
1
,
gate
.
top_k
)
gate_score
=
F
.
softmax
(
gate_top_k_val
,
dim
=-
1
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
for
j
in
range
(
gate
.
top_k
):
for
j
in
range
(
gate
.
top_k
):
v
=
topk_idx
[
i
,
j
]
v
=
topk_idx
[
i
,
j
]
if
v
!=
-
1
:
if
v
!=
-
1
:
assert
topk_val
[
i
,
j
]
==
gate_score
[
i
,
v
]
assert
topk_val
[
i
,
j
]
==
gate_score
[
i
,
j
]
@
pytest
.
mark
.
parametrize
(
"d_model"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"d_model"
,
[
1024
])
...
@@ -109,7 +118,7 @@ if __name__ == '__main__':
...
@@ -109,7 +118,7 @@ if __name__ == '__main__':
args
=
json
.
loads
(
sys
.
argv
[
2
])
args
=
json
.
loads
(
sys
.
argv
[
2
])
locals
()[
sys
.
argv
[
1
]](
**
args
)
locals
()[
sys
.
argv
[
1
]](
**
args
)
else
:
else
:
_ensure_initialized
()
#
_ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2)
# test_gshard_gate(4096, 1024, 4, .2)
test_
switch
_gate
(
8
,
16
,
4
,
.
1
)
_
test_
gshard
_gate
(
8
,
16
,
4
,
.
1
)
# test_switch_gate(4096, 1024, 4, .2)
# test_switch_gate(4096, 1024, 4, .2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment