Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
76327544
Commit
76327544
authored
Jan 11, 2021
by
Rick Ho
Browse files
distributed test weight
parent
864a4522
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
7 deletions
+11
-7
pytorch/cuda/moe_test.py
pytorch/cuda/moe_test.py
+11
-7
No files found.
pytorch/cuda/moe_test.py
View file @
76327544
...
@@ -67,15 +67,15 @@ def test_module(moe, linear, inp, gate):
...
@@ -67,15 +67,15 @@ def test_module(moe, linear, inp, gate):
moe
.
zero_grad
()
moe
.
zero_grad
()
x
=
(
linear
(
inp
))
x
=
(
linear
(
inp
))
output
=
moe
(
x
,
gate
)
output
=
moe
(
x
,
gate
)
print
(
'ooutput'
,
torch
.
distributed
.
get_rank
(),
output
)
#
print('ooutput', torch.distributed.get_rank(), output)
y
=
output
.
mean
()
y
=
output
.
mean
()
y
.
backward
()
y
.
backward
()
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
return
output
,
moe
.
weight
.
grad
,
linear
.
weight
.
grad
,
linear
.
bias
.
grad
def
test
():
def
test
():
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
+
torch
.
distributed
.
get_rank
()
)
torch
.
cuda
.
manual_seed
(
42
)
torch
.
cuda
.
manual_seed
(
42
+
torch
.
distributed
.
get_rank
()
)
batch_size
=
4
batch_size
=
4
num_expert
=
2
num_expert
=
2
in_feat
=
6
in_feat
=
6
...
@@ -106,10 +106,14 @@ def test():
...
@@ -106,10 +106,14 @@ def test():
moe_out
=
test_module
(
moe
,
linear
,
inp
.
clone
(),
gate
.
clone
())
moe_out
=
test_module
(
moe
,
linear
,
inp
.
clone
(),
gate
.
clone
())
raw_out
=
test_module
(
moe_raw
,
linear
,
inp
.
clone
(),
gate
.
clone
())
raw_out
=
test_module
(
moe_raw
,
linear
,
inp
.
clone
(),
gate
.
clone
())
if
world_size
==
1
:
names
=
[
'Out'
,
'Moe wei'
,
'Linear wei'
,
'Linear bias'
]
names
=
[
'Out'
,
'Moe wei'
,
'Linear wei'
,
'Linear bias'
]
if
world_size
>
1
:
else
:
rank
=
torch
.
distributed
.
get_rank
()
names
=
[
'Out'
]
ou
,
wg
,
lwg
,
lbg
=
raw_out
wg
=
wg
.
cpu
()
torch
.
distributed
.
all_reduce
(
wg
)
wg
=
wg
[
rank
*
num_expert
:(
rank
+
1
)
*
num_expert
]
raw_out
=
ou
,
wg
.
cuda
(),
lwg
,
lbg
for
name
,
mo
,
ro
in
zip
(
names
,
moe_out
,
raw_out
):
for
name
,
mo
,
ro
in
zip
(
names
,
moe_out
,
raw_out
):
err
=
(
mo
-
ro
).
abs
().
sum
()
err
=
(
mo
-
ro
).
abs
().
sum
()
print
(
'{} abs err {}'
.
format
(
name
,
err
))
print
(
'{} abs err {}'
.
format
(
name
,
err
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment