Commit 97929735 authored by dongcl

reorder pipeline

parent 1d497357
@@ -632,6 +632,10 @@ def schedule_layer_1f1b(
         b_grad = pre_backward()
         del pre_backward
+    if f_layer is not None:
+        with f_context:
+            f_input = f_layer.attn.forward(f_input)
     if b_layer is not None:
         with b_context:
             routed_expert_output_grad, shared_expert_output_grad = b_layer.combine.backward(b_grad)
@@ -640,10 +644,6 @@ def schedule_layer_1f1b(
         pre_backward_dw()
         del pre_backward_dw
-    if f_layer is not None:
-        with f_context:
-            f_input = f_layer.attn.forward(f_input)
     f_dispatch_b_mlp_sync_event = None
     if f_layer is not None and b_layer is not None:
         f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT
@@ -653,13 +653,8 @@ def schedule_layer_1f1b(
             shared_expert_output = f_layer.shared_expert.forward()
             f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
-    # if f_layer is not None:
-    #     with f_context:
-    #         f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
     if b_layer is not None:
         with b_context:
-            # routed_expert_output_grad, shared_expert_output_grad = b_grad
             b_grad = b_layer.routed_expert.backward(routed_expert_output_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
             b_layer.shared_expert.backward(shared_expert_output_grad)
             b_grad = b_layer.dispatch.backward(b_grad)
@@ -669,13 +664,6 @@ def schedule_layer_1f1b(
         with f_context:
             f_input = f_layer.routed_expert.forward(f_input)
-    # if b_layer is not None:
-    #     with b_context:
-    #         # b_grad = b_layer.dispatch.backward(b_grad)
-    #         b_layer.shared_expert.backward(shared_expert_output_grad)
-    #         b_layer.routed_expert.dw()
     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
...
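For orientation, below is a minimal sketch of the per-layer 1F1B step as it reads after this reorder, reconstructed from the hunks above. Everything outside the shown hunks (the pre_backward/pre_backward_dw closures, the stream contexts, and the F_DISPATCH_B_MLP_SYNC_EVENT object) is a simplified assumption, not the repository's actual code.

# Sketch of the reordered 1F1B layer step. Names follow the diff above;
# the event and the closures are placeholders assumed for illustration.
F_DISPATCH_B_MLP_SYNC_EVENT = object()  # stands in for the real CUDA event

def schedule_layer_1f1b_sketch(f_layer, b_layer, f_context, b_context,
                               f_input, pre_backward, pre_backward_dw):
    # Backward: run the pending pre-backward work for the backward layer.
    b_grad = pre_backward()

    # Forward attention now runs *before* combine.backward (moved up by this
    # commit), so attention compute can overlap the combine phase.
    if f_layer is not None:
        with f_context:
            f_input = f_layer.attn.forward(f_input)

    # Backward: split the combined output grad into routed/shared expert grads.
    if b_layer is not None:
        with b_context:
            routed_expert_output_grad, shared_expert_output_grad = \
                b_layer.combine.backward(b_grad)

    pre_backward_dw()

    # The sync event is only needed when a forward and a backward layer coexist.
    f_dispatch_b_mlp_sync_event = None
    if f_layer is not None and b_layer is not None:
        f_dispatch_b_mlp_sync_event = F_DISPATCH_B_MLP_SYNC_EVENT

    # Forward: shared expert plus dispatch, recording the event so the backward
    # expert MLP below can wait on the dispatch stream.
    if f_layer is not None:
        with f_context:
            shared_expert_output = f_layer.shared_expert.forward()
            f_input = f_layer.dispatch.forward(
                f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
            # shared_expert_output feeds the combine step outside these hunks.

    # Backward: routed/shared expert grads, then dispatch backward.
    if b_layer is not None:
        with b_context:
            b_grad = b_layer.routed_expert.backward(
                routed_expert_output_grad,
                stream_wait_event=f_dispatch_b_mlp_sync_event)
            b_layer.shared_expert.backward(shared_expert_output_grad)
            b_grad = b_layer.dispatch.backward(b_grad)

    # Forward: routed expert MLP on the dispatched tokens.
    if f_layer is not None:
        with f_context:
            f_input = f_layer.routed_expert.forward(f_input)

    return f_input, b_grad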