Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
2ba58797
Commit
2ba58797
authored
Dec 30, 2020
by
Jiezhong Qiu
Browse files
update
parent
560d3f1b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
10 deletions
+19
-10
pytorch/cuda/moe.py
pytorch/cuda/moe.py
+19
-10
No files found.
pytorch/cuda/moe.py
View file @
2ba58797
...
@@ -80,6 +80,8 @@ def test():
...
@@ -80,6 +80,8 @@ def test():
in_feat
=
2
in_feat
=
2
out_feat
=
3
out_feat
=
3
linear
=
nn
.
Linear
(
in_feat
,
in_feat
).
cuda
()
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe
=
MOELayer
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe_raw
=
MOELayer_raw
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe_raw
=
MOELayer_raw
(
num_expert
,
in_feat
,
out_feat
).
cuda
()
moe_raw
.
weight
.
data
=
moe
.
weight
.
data
.
clone
()
moe_raw
.
weight
.
data
=
moe
.
weight
.
data
.
clone
()
...
@@ -87,21 +89,28 @@ def test():
...
@@ -87,21 +89,28 @@ def test():
inp
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
inp
=
torch
.
rand
(
batch_size
,
in_feat
).
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
gate
=
torch
.
randint
(
low
=
0
,
high
=
num_expert
,
size
=
(
batch_size
,
),
requires_grad
=
False
).
int
().
cuda
()
output
=
moe
(
inp
,
gate
)
linear
.
zero_grad
()
output_raw
=
moe_raw
(
inp
.
clone
(),
gate
.
clone
())
moe
.
zero_grad
()
x
=
linear
(
inp
)
print
(
output
)
output
=
moe
(
x
,
gate
)
print
(
output_raw
)
print
(
"moe output"
,
output
)
y
=
output
.
mean
()
y
=
output
.
mean
()
y
.
backward
()
y
.
backward
()
print
(
"moe.weight.grad"
,
moe
.
weight
.
grad
)
print
(
"linear.weight.grad"
,
linear
.
weight
.
grad
)
print
(
"linear.bias.grad"
,
linear
.
bias
.
grad
)
linear
.
zero_grad
()
moe
.
zero_grad
()
x
=
linear
(
inp
.
clone
())
output_raw
=
moe_raw
(
x
,
gate
.
clone
())
print
(
"moe_raw output"
,
output_raw
)
y_raw
=
output_raw
.
mean
()
y_raw
=
output_raw
.
mean
()
y_raw
.
backward
()
y_raw
.
backward
()
print
(
"moe_raw.weight.grad"
,
moe_raw
.
weight
.
grad
)
print
(
moe
.
weight
.
grad
)
print
(
"linear_raw.weight.grad"
,
linear
.
weight
.
grad
)
print
(
moe_raw
.
weight
.
grad
)
print
(
"linear_raw.bias.grad"
,
linear
.
bias
.
grad
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test
()
test
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment