OpenDAS / FastMoE · Commits

Commit 6c68b56b
authored Mar 30, 2022 by Rick Ho

fix backward grad weight bug

parent a807e2a3
Showing 3 changed files with 7 additions and 6 deletions:

cuda/fastermoe/smart_schedule.cpp  +3 -2
cuda/fmoe_cuda.cpp                 +1 -1
fmoe/fastermoe/schedule.py         +3 -3
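The bug named in the commit title concerns how the expert weight gradient is formed. Each expert consumes the globally exchanged rows (the input after the all-to-all), so dL/dW must be computed against those same rows; forming it from the pre-exchange local buffer yields a silently wrong gradient. A minimal single-process sketch of the failure mode, with a permutation standing in for the all-to-all (hypothetical code, not FastMoE's kernels):

import torch

torch.manual_seed(0)
local_x = torch.randn(6, 4)           # rows this worker owns before the exchange
perm = torch.randperm(6)              # stand-in for the all-to-all exchange
global_x = local_x[perm]              # the rows the expert actually consumes
grad_y = torch.randn(6, 5)            # upstream gradient of y = global_x @ w.t()

grad_w_right = grad_y.t() @ global_x  # formed from the exchanged rows
grad_w_wrong = grad_y.t() @ local_x   # formed from the pre-exchange rows
print(torch.allclose(grad_w_right, grad_w_wrong))  # False in general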
cuda/fastermoe/smart_schedule.cpp

@@ -9,7 +9,7 @@
 long pipeline_gran = -1;

-torch::Tensor _smart_sch_forward(
+std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
@@ -33,6 +33,7 @@ torch::Tensor _smart_sch_forward(
     const auto num_expert = local_expert_count.size(0) / n_workers;
     const auto d_model = input_buf.size(1);
     // TODO: maybe empty is faster
+    auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
     auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
@@ -55,7 +56,7 @@ torch::Tensor _smart_sch_forward(
             d_model, num_expert, rank, n_workers, pipeline_gran, smgr);
     }));
-    return output_buf;
+    return {output_buf, global_input_buf};
 }

 torch::Tensor _smart_sch_backward(
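Returning global_input_buf alongside output_buf lets the caller keep the exact buffer the experts consumed alive for the backward pass instead of reconstructing it. Reduced to a self-contained PyTorch sketch (a hypothetical ToyExpert with a plain matmul in place of the fused kernels), the pattern is: whatever tensor the forward math actually reads is the one handed to backward:

import torch
from torch.autograd import Function

class ToyExpert(Function):
    """Hypothetical stand-in for a fused expert forward/backward."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight.t()
        # Save the tensor the expert actually consumed; saving anything
        # else makes grad_weight silently wrong.
        ctx.save_for_backward(inp, weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_out @ weight      # dL/dx
        grad_weight = grad_out.t() @ inp  # dL/dW, from the saved input
        return grad_inp, grad_weight

# quick shape check: batch of 4, d_model 8, expert output dim 3
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)
ToyExpert.apply(x, w).sum().backward()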
cuda/fmoe_cuda.cpp

@@ -58,7 +58,7 @@ std::vector<torch::Tensor> _swipe_once(
         long n_expert, long n_worker,
         long bias);

 // smart scheduling
-torch::Tensor _smart_sch_forward(
+std::vector<torch::Tensor> _smart_sch_forward(
         torch::Tensor input_buf,
         torch::Tensor local_expert_count,
         torch::Tensor global_expert_count,
fmoe/fastermoe/schedule.py

@@ -36,7 +36,7 @@ class MoEForward(Function):
             ctx.gobs[idx] = y0
             y.copy_(y0)

-        local_output_buf = fmoe_native.smart_sch_forward(
+        local_output_buf, gib = fmoe_native.smart_sch_forward(
                 local_input_buf,
                 local_expert_count, global_expert_count,
                 stored_models, fwd_batch_size,
@@ -46,7 +46,7 @@ class MoEForward(Function):
                 maybe_overlap=False)

         variables = (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models)
+                stored_models, gib)
         ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
         ctx.save_for_backward(*variables)
@@ -56,7 +56,7 @@ class MoEForward(Function):
     @staticmethod
     def backward(ctx, grad_out):
         (pos_s, pos_g, local_expert_count, global_expert_count,
-                stored_models) = ctx.saved_tensors
+                stored_models, _) = ctx.saved_tensors
         (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args

         def _expert_backward(grad_y, grad_x, idx):
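On the Python side, gib only needs to survive until backward runs, so it is appended to the save_for_backward tuple and discarded with _ when unpacked; save_for_backward both keeps the tensor alive and lets autograd flag in-place modification between forward and backward. A sketch of the same pattern (hypothetical SavesAux, unrelated to the real MoEForward):

import torch
from torch.autograd import Function

class SavesAux(Function):
    @staticmethod
    def forward(ctx, inp):
        aux = inp.detach().clone()   # plays the role of gib
        # Saved tensors are retained (and version-checked) until backward.
        ctx.save_for_backward(inp, aux)
        return inp * 2

    @staticmethod
    def backward(ctx, grad_out):
        inp, _ = ctx.saved_tensors   # unpacking must match what was saved
        return grad_out * 2

x = torch.randn(3, requires_grad=True)
SavesAux.apply(x).sum().backward()   # x.grad == tensor([2., 2., 2.])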