Commit 97929735 authored by dongcl

reorder pipeline: move the forward attention step ahead of the combine backward in schedule_layer_1f1b, and drop stale commented-out code

parent 1d497357
@@ -632,6 +632,10 @@ def schedule_layer_1f1b(
         b_grad = pre_backward()
         del pre_backward
 
+    if f_layer is not None:
+        with f_context:
+            f_input = f_layer.attn.forward(f_input)
+
     if b_layer is not None:
         with b_context:
             routed_expert_output_grad, shared_expert_output_grad = b_layer.combine.backward(b_grad)
@@ -640,10 +644,6 @@ def schedule_layer_1f1b(
         pre_backward_dw()
         del pre_backward_dw
 
-    if f_layer is not None:
-        with f_context:
-            f_input = f_layer.attn.forward(f_input)
-
     f_dispatch_b_mlp_sync_event = None
     if f_layer is not None and b_layer is not None:
         f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
@@ -653,13 +653,8 @@ def schedule_layer_1f1b(
             shared_expert_output = f_layer.shared_expert.forward()
             f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
-    # if f_layer is not None:
-    #     with f_context:
-    #         f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
-
     if b_layer is not None:
         with b_context:
-            # routed_expert_output_grad, shared_expert_output_grad = b_grad
             b_grad = b_layer.routed_expert.backward(routed_expert_output_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_layer.shared_expert.backward(shared_expert_output_grad)
             b_grad = b_layer.dispatch.backward(b_grad)
@@ -669,13 +664,6 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.routed_expert.forward(f_input)
-    # if b_layer is not None:
-    #     with b_context:
-    #         # b_grad = b_layer.dispatch.backward(b_grad)
-    #         b_layer.shared_expert.backward(shared_expert_output_grad)
-    #         b_layer.routed_expert.dw()
-
     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
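Read end to end, the hunks leave the per-layer 1F1B interleaving in roughly the order below. This is a condensed, hypothetical sketch for orientation only: the `if f_layer`/`if b_layer` guards, the `f_context`/`b_context` stream contexts, the `pre_backward*` hooks, and the event plumbing of the real function are all elided.

```python
# Schematic of one schedule_layer_1f1b step after this commit, condensed
# from the hunks above. Method names mirror the patch; everything else
# (function name, argument handling) is illustrative.
def layer_step_sketch(f_layer, b_layer, f_input, b_grad):
    f_input = f_layer.attn.forward(f_input)            # forward attention (moved up by this commit)
    r_grad, s_grad = b_layer.combine.backward(b_grad)  # backward combine
    shared_out = f_layer.shared_expert.forward()       # forward shared expert
    f_input = f_layer.dispatch.forward(f_input)        # forward dispatch (records the sync event)
    b_grad = b_layer.routed_expert.backward(r_grad)    # backward routed expert (waits on the event)
    b_layer.shared_expert.backward(s_grad)             # backward shared expert
    b_grad = b_layer.dispatch.backward(b_grad)         # backward dispatch
    f_input = f_layer.routed_expert.forward(f_input)   # forward routed expert
    return f_input, b_grad
```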
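The overlap in the third hunk rests on a CUDA-event handshake: `dispatch.forward` records `F_DISPATCH_B_MLP_SYNC_EVENT` on the forward communication stream via `stream_record_event`, and `routed_expert.backward` orders its stream after that point via `stream_wait_event`. A minimal sketch of that mechanism, assuming PyTorch CUDA streams; the stream names and stand-in functions below are illustrative, not taken from this repository, and it needs a CUDA device to run:

```python
import torch

comm_stream = torch.cuda.Stream()     # carries the forward all-to-all dispatch
compute_stream = torch.cuda.Stream()  # carries the backward expert GEMMs

F_DISPATCH_B_MLP_SYNC_EVENT = torch.cuda.Event()

def dispatch_forward(x, stream_record_event=None):
    # Hypothetical stand-in for f_layer.dispatch.forward.
    with torch.cuda.stream(comm_stream):
        y = x * 1.0  # placeholder for the token all-to-all
        if stream_record_event is not None:
            # Mark the point on the comm stream that the backward
            # side must not run ahead of.
            stream_record_event.record(comm_stream)
    return y

def routed_expert_backward(grad, stream_wait_event=None):
    # Hypothetical stand-in for b_layer.routed_expert.backward.
    if stream_wait_event is not None:
        # Order compute_stream after the recorded point on comm_stream,
        # without blocking the host.
        compute_stream.wait_event(stream_wait_event)
    with torch.cuda.stream(compute_stream):
        return grad * 1.0  # placeholder for the expert backward GEMMs
```

With that handshake in place, pulling `attn.forward` ahead of `combine.backward` (the first two hunks) lets the forward microbatch's attention compute run while the backward combine's communication is still in flight, which appears to be the point of the reorder.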