Commit bb6ab0fb authored by dongcl

rewrite combined_1f1b

parent 69add73b
@@ -361,13 +361,14 @@ def forward_backward_step(
     context_manager = contextlib.nullcontext()
 
     # forward preprocess
+    unwrap_output_tensor = False
+
     if f_model is not None:
         with f_context:
             if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
                 f_model.set_is_first_microbatch()
             if current_microbatch is not None:
                 set_current_microbatch(f_model, current_microbatch)
-            unwrap_output_tensor = False
             if not isinstance(input_tensor, list):
                 input_tensor = [input_tensor]
                 unwrap_output_tensor = True
@@ -387,10 +388,10 @@ def forward_backward_step(
     ), "first output of forward_step_func must be one instance of AbstractSchedulePlan"
 
     # backward preprocess
+    unwrap_input_tensor_grad = False
     b_schedule_plan = None
     if b_model is not None:
         # Retain the grad on the input_tensor.
-        unwrap_input_tensor_grad = False
         if not isinstance(b_input_tensor, list):
             b_input_tensor = [b_input_tensor]
             unwrap_input_tensor_grad = True
...
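
Both hunks make the same fix: the unwrap flag is now initialized before the if f_model is not None: / if b_model is not None: branch instead of inside it, so the flag is defined even on steps of the combined 1F1B schedule that run only a forward or only a backward pass. Below is a minimal sketch of the failure mode the hoist presumably avoids (illustrative names, not the repository's actual forward_backward_step), assuming the flag is read after the branch:

# Sketch of the hazard (names are illustrative, not the real implementation).

def step_before_fix(f_model, input_tensor):
    if f_model is not None:
        unwrap_output_tensor = False  # only defined when f_model is present
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
            unwrap_output_tensor = True
    # Reading the flag here raises UnboundLocalError when f_model is None.
    if unwrap_output_tensor:
        input_tensor = input_tensor[0]
    return input_tensor

def step_after_fix(f_model, input_tensor):
    unwrap_output_tensor = False  # hoisted: defined on every code path
    if f_model is not None:
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
            unwrap_output_tensor = True
    if unwrap_output_tensor:
        input_tensor = input_tensor[0]
    return input_tensor

step_after_fix(None, None)  # fine: flag is always bound
try:
    step_before_fix(None, None)
except UnboundLocalError as e:
    print(e)  # exact message varies by Python version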