Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
90c4bccf
Commit
90c4bccf
authored
May 19, 2021
by
Rich Ho
Browse files
fix scatter bug across gpus
parent
670d70ed
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
fmoe/functions.py
fmoe/functions.py
+4
-4
No files found.
fmoe/functions.py
View file @
90c4bccf
...
@@ -117,7 +117,7 @@ class MOEScatter(Function):
...
@@ -117,7 +117,7 @@ class MOEScatter(Function):
             )
         else:
             global_input_buf = local_input_buf
-        ctx.moe_args = inp.shape[0], world_size
+        ctx.moe_args = inp.shape[0], pos.shape[0], world_size
         variables = (pos, local_expert_count, global_expert_count)
         ctx.save_for_backward(*variables)
         return global_input_buf
...
@@ -125,19 +125,19 @@ class MOEScatter(Function):
...
@@ -125,19 +125,19 @@ class MOEScatter(Function):
     @staticmethod
     def backward(ctx, global_grad_in):
         (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
-        (local_batch_size, world_size) = ctx.moe_args
+        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args
         if world_size > 1:
             (local_grad_in,) = fmoe_cuda.global_gather(
                 global_grad_in,
                 local_expert_count,
                 global_expert_count,
-                local_batch_size,
+                buf_batch_size,
                 world_size,
             )
         else:
             local_grad_in = global_grad_in
-        grad_in = _local_gather(local_grad_in, pos, local_batch_size)
+        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
         return grad_in, None, None, None, None, None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment