Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
952e3135
Commit
952e3135
authored
Jan 28, 2021
by
Rick Ho
Browse files
fix scatter/gather bug to make it correct
parent
2d250fbf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
8 deletions
+7
-8
fmoe/moe_function.py
fmoe/moe_function.py
+5
-6
tests/moe_test.py
tests/moe_test.py
+2
-2
No files found.
fmoe/moe_function.py
View file @
952e3135
...
...
@@ -14,7 +14,7 @@ class MOELocal(Function):
# expert_count, pos = fmoe_cuda.expert_count(gate, weight.shape[0])
ecc
=
expert_count
.
cpu
()
input_buf
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
inp
,
pos
)
input_buf
,
=
fmoe_cuda
.
local_
g
at
h
er
(
inp
,
pos
)
output_buf
,
=
fmoe_cuda
.
forward
(
input_buf
,
weight
,
ecc
)
output
=
fmoe_cuda
.
local_gather
(
output_buf
,
pos
)
...
...
@@ -52,13 +52,12 @@ class MOEGlobal(Function):
global_expert_count
,
=
fmoe_cuda
.
expert_exchange
(
local_expert_count
,
num_expert
,
world_size
)
print
(
'Local {} Global {}'
.
format
(
local_expert_count
,
global_expert_count
))
fwd_expert_count
=
global_expert_count
.
view
(
num_expert
,
world_size
).
sum
(
dim
=
1
).
cpu
()
fwd_expert_count
=
global_expert_count
.
view
(
world_size
,
num_expert
).
sum
(
dim
=
0
).
cpu
()
fwd_batch_size
=
int
(
fwd_expert_count
.
sum
().
item
())
local_input_buf
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
inp
,
pos
)
local_input_buf
,
=
fmoe_cuda
.
local_
g
at
h
er
(
inp
,
pos
)
local_expert_count
=
local_expert_count
.
cpu
()
global_expert_count
=
global_expert_count
.
cpu
()
...
...
@@ -67,7 +66,7 @@ class MOEGlobal(Function):
local_expert_count
,
global_expert_count
,
fwd_batch_size
,
inp
.
shape
[
0
],
world_size
)
output
,
=
fmoe_cuda
.
local_
g
at
h
er
(
local_output_buf
,
pos
)
output
,
=
fmoe_cuda
.
local_
sc
at
t
er
(
local_output_buf
,
pos
)
variables
=
(
global_input_buf
,
gate
,
weight
,
local_expert_count
,
global_expert_count
,
fwd_expert_count
,
...
...
tests/moe_test.py
View file @
952e3135
...
...
@@ -135,9 +135,9 @@ def test():
print
(
'Rank {} {} abs err {}'
.
format
(
rank
,
name
,
err
))
if
err
>
1e-3
:
sys
.
stderr
.
write
(
'=========== moe out ==============
\n
'
)
sys
.
stderr
.
write
(
'{}'
.
format
(
mo
))
sys
.
stderr
.
write
(
'{}
\n
'
.
format
(
mo
))
sys
.
stderr
.
write
(
'=========== raw out ==============
\n
'
)
sys
.
stderr
.
write
(
'{}'
.
format
(
ro
))
sys
.
stderr
.
write
(
'{}
\n
'
.
format
(
ro
))
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment