Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
f804a121
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4774fe3afc61b40a56244e9411a7c3e64ae8147f"
Commit
f804a121
authored
May 19, 2021
by
Rick Ho
Browse files
update test
parent
980cf4b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
7 deletions
+8
-7
tests/moe.py
tests/moe.py
+5
-3
tests/test_gates.py
tests/test_gates.py
+2
-2
tests/test_numerical.py
tests/test_numerical.py
+1
-2
No files found.
tests/moe.py
View file @
f804a121
...
@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module):
...
@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module):
self
.
top_k
=
top_k
self
.
top_k
=
top_k
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
gate_long
=
gate_idx
.
long
()
inp
=
inp
.
repeat_interleave
(
repeats
=
self
.
top_k
,
dim
=
0
)
gate_long
=
gate_idx
.
long
().
view
(
-
1
)
batch_size
=
inp
.
size
(
0
)
batch_size
=
inp
.
size
(
0
)
o
=
torch
.
empty
(
batch_size
,
self
.
d_model
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
o
=
torch
.
empty
(
batch_size
,
self
.
d_model
,
dtype
=
inp
.
dtype
,
device
=
inp
.
device
)
for
i
in
range
(
self
.
weight_htoh4
.
shape
[
0
]):
for
i
in
range
(
self
.
weight_htoh4
.
shape
[
0
]):
idx
=
gate_
idx
==
i
idx
=
gate_
long
==
i
x
=
inp
[
idx
]
x
=
inp
[
idx
]
x
=
x
@
self
.
weight_htoh4
[
i
].
t
()
x
=
x
@
self
.
weight_htoh4
[
i
].
t
()
x
=
x
+
self
.
bias_htoh4
[
i
]
x
=
x
+
self
.
bias_htoh4
[
i
]
...
@@ -56,7 +57,8 @@ class BruteForceMoE(nn.Module):
...
@@ -56,7 +57,8 @@ class BruteForceMoE(nn.Module):
self
.
experts
=
[
expert
(
d_model
)
for
_
in
range
(
num_expert
*
world_size
)]
self
.
experts
=
[
expert
(
d_model
)
for
_
in
range
(
num_expert
*
world_size
)]
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
def
forward
(
self
,
inp
,
gate_idx
,
gate_score
):
gate_long
=
gate_idx
.
long
()
inp
=
inp
.
repeat_interleave
(
repeats
=
self
.
top_k
,
dim
=
0
)
gate_long
=
gate_idx
.
long
().
view
(
-
1
)
batch_size
=
inp
.
size
(
0
)
batch_size
=
inp
.
size
(
0
)
x
=
inp
.
new_zeros
((
batch_size
,
self
.
d_model
))
x
=
inp
.
new_zeros
((
batch_size
,
self
.
d_model
))
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
...
...
tests/test_gates.py
View file @
f804a121
...
@@ -58,5 +58,5 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
...
@@ -58,5 +58,5 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
_ensure_initialized
()
_ensure_initialized
()
#
test_gshard_gate(4096, 1024, 4, .2)
test_gshard_gate
(
4096
,
1024
,
4
,
.
2
)
test_switch_gate
(
4096
,
1024
,
4
,
.
2
)
#
test_switch_gate(4096, 1024, 4, .2)
tests/test_numerical.py
View file @
f804a121
...
@@ -39,9 +39,8 @@ def _perform_forward(
...
@@ -39,9 +39,8 @@ def _perform_forward(
inp_raw
.
requires_grad
=
True
inp_raw
.
requires_grad
=
True
gate_idx
,
gate_score
=
moe
.
gate
(
inp_raw
)
gate_idx
,
gate_score
=
moe
.
gate
(
inp_raw
)
inp_repeated
=
inp_raw
.
repeat_interleave
(
repeats
=
top_k
,
dim
=
0
)
moe_out
=
moe
(
inp
)
moe_out
=
moe
(
inp
)
raw_out
=
moe_raw
(
inp_r
epeated
,
gate_idx
,
gate_score
)
raw_out
=
moe_raw
(
inp_r
aw
,
gate_idx
,
gate_score
)
raw_out
.
mean
().
backward
()
raw_out
.
mean
().
backward
()
moe_out
.
mean
().
backward
()
moe_out
.
mean
().
backward
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment