Commit 6c68b56b authored by Rick Ho
Browse files

fix backward grad weight bug

parent a807e2a3
......@@ -9,7 +9,7 @@
long pipeline_gran = -1;
torch::Tensor _smart_sch_forward(
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
......@@ -33,6 +33,7 @@ torch::Tensor _smart_sch_forward(
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto d_model = input_buf.size(1);
// TODO: maybe empty is faster
auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
......@@ -55,7 +56,7 @@ torch::Tensor _smart_sch_forward(
d_model, num_expert, rank, n_workers,
pipeline_gran, smgr);
}));
return output_buf;
return {output_buf, global_input_buf};
}
torch::Tensor _smart_sch_backward(
......
......@@ -58,7 +58,7 @@ std::vector<torch::Tensor> _swipe_once(
long n_expert, long n_worker, long bias);
// smart scheduling
torch::Tensor _smart_sch_forward(
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
......
......@@ -36,7 +36,7 @@ class MoEForward(Function):
ctx.gobs[idx] = y0
y.copy_(y0)
local_output_buf = fmoe_native.smart_sch_forward(
local_output_buf, gib = fmoe_native.smart_sch_forward(
local_input_buf,
local_expert_count, global_expert_count,
stored_models, fwd_batch_size,
......@@ -46,7 +46,7 @@ class MoEForward(Function):
maybe_overlap=False)
variables = (pos_s, pos_g, local_expert_count, global_expert_count,
stored_models)
stored_models, gib)
ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
ctx.save_for_backward(*variables)
......@@ -56,7 +56,7 @@ class MoEForward(Function):
@staticmethod
def backward(ctx, grad_out):
(pos_s, pos_g, local_expert_count, global_expert_count,
stored_models) = ctx.saved_tensors
stored_models, _) = ctx.saved_tensors
(fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args
def _expert_backward(grad_y, grad_x, idx):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment