OpenDAS / FastMoE · Commits

Commit 704092b1, authored Feb 23, 2021 by Sengxian
Fix input grad in mp group
Parent: 092c8d67

Showing 2 changed files with 27 additions and 6 deletions:
    fmoe/functions.py  +24  -0
    fmoe/layers.py      +3  -6

fmoe/functions.py
@@ -192,3 +192,27 @@ class AllGather(Function):
     def backward(ctx, grad_out):
         rank, dim0 = ctx.args
         return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
+
+
+class Slice(Function):
+    r'''
+    A wrapper for the Slice function to support auto-differentiation.
+    '''
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        B: int = inp.shape[0]
+        local_batch_size = B // world_size
+        batch_start = local_batch_size * rank
+        batch_end = min(batch_start + local_batch_size, B)
+        inp = inp[batch_start:batch_end]
+        ctx.args = world_size, group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        world_size, group = ctx.args
+        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, grad_out, group=group)
+        torch.cuda.synchronize()
+        grad_out = torch.cat(tensor_list, dim=0)
+        return grad_out, None, None, None
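Slice is the mirror of the AllGather function above it: forward keeps only the local rank's rows of the batch, and backward all-gathers the per-rank gradient shards so every rank recovers the gradient of the full input. Below is a minimal usage sketch on a one-rank "gloo" group; the process-group setup is illustrative and not part of the commit, and a CUDA build of PyTorch is assumed because Slice.backward calls torch.cuda.synchronize().

# Minimal sketch (assumptions: single process, "gloo" backend, CUDA build).
import os
import torch
import torch.distributed as dist
from fmoe.functions import Slice

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

inp = torch.randn(4, 8, requires_grad=True)
# Forward keeps this rank's shard: the whole batch here (world_size == 1);
# with N ranks each rank would keep batch // N rows.
# Passing group=None makes the backward all-gather use the default group.
out = Slice.apply(inp, dist.get_rank(), dist.get_world_size(), None)
out.sum().backward()
# Backward all-gathers the gradient shards, so the gradient covers the
# full, unsliced input.
assert inp.grad.shape == inp.shape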
fmoe/layers.py
@@ -8,7 +8,7 @@ import numpy as np
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
-from .functions import AllGather
+from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -179,11 +179,8 @@ class FMoE(nn.Module):
         expert is multiplied to the experts' output tensors as a weight.
         '''
         if self.mp_size > 1:
-            B: int = inp.shape[0]
-            local_batch_size = B // self.mp_size
-            batch_start = local_batch_size * self.mp_rank
-            batch_end = min(batch_start + local_batch_size, B)
-            inp = inp[batch_start:batch_end]
+            inp = Slice.apply(inp,
+                self.mp_rank, self.mp_size, self.mp_group)

         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
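The layers.py change is the actual fix: plain tensor indexing is differentiable, but its backward only writes into the locally kept rows, so each model-parallel rank used to see zeros for every other rank's share of the input gradient. A single-process sketch of that failure mode (the 4-row batch and mp group of size 2 are made-up numbers for illustration; no real mp group is needed to see it):

import torch

# Simulate rank 0 of an mp group of size 2 under the old code: slice the
# batch by plain indexing and back-propagate only through the local half.
inp = torch.randn(4, 2, requires_grad=True)
local = inp[0:2]           # rank 0's rows
local.sum().backward()
print(inp.grad)            # rows 2..3 are zero: rank 1's contribution is
                           # never combined back, which is exactly what
                           # Slice's all-gather backward restores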