Commit 2385a133 authored by dongcl

modify dualpipev, add forward_step_helper and backward_step_helper

parent cb1230db
@@ -354,9 +354,9 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2
num_interleaved_backward_stages[i] = i + 1
num_interleaved_backward_stages[i] = (i + 1) * 2
num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1]
num_cooldown_stages[i] = [pp_size // 2 - i - 1, pp_size - 2 * i - 2, i + 1]
schedule_all_stages = {
'warmup': num_warmup_stages,
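For reference, a minimal standalone sketch (not part of the patch) of the per-rank stage counts touched by this hunk, assuming the second of each duplicated line above is the new formula; the warmup count and the surrounding loop are omitted because the diff does not show them:

```python
def sketch_stage_counts_for_rank(pp_size, rank):
    """Per-rank stage counts using only the formulas visible in this hunk."""
    num_1b1overlap = (pp_size // 2 - rank - 1) * 2
    num_interleaved_backward = (rank + 1) * 2       # previously: rank + 1
    num_cooldown = [pp_size // 2 - rank - 1,        # previously: rank + 1
                    pp_size - 2 * rank - 2,
                    rank + 1]
    return num_1b1overlap, num_interleaved_backward, num_cooldown

# Example: pp_size = 8 -> rank 0 gives (6, 2, [3, 6, 1]), rank 3 gives (0, 8, [0, 0, 4]).
```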
@@ -537,10 +537,33 @@ def forward_backward_pipelining_with_cutinhalf(
# Disable async grad reductions
no_sync_func = config.no_sync_func
if isinstance(no_sync_func, list):
def multi_no_sync():
stack = contextlib.ExitStack()
for model_chunk_no_sync_func in config.no_sync_func:
stack.enter_context(model_chunk_no_sync_func())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
config.grad_sync_func = [config.grad_sync_func for _ in model]
if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
config.param_sync_func = [config.param_sync_func for _ in model]
# Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
# They will be re-enabled at the end of this function.
grad_sync_func, param_sync_func = None, None
if forward_only:
grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
config.grad_sync_func, config.param_sync_func = None, None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
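As an aside, here is a self-contained illustration (with hypothetical per-chunk no_sync callables, not the patch's API) of the ExitStack pattern used above to merge one no_sync() context per model chunk into a single callable:

```python
import contextlib


def make_multi_no_sync(per_chunk_no_sync_funcs):
    """Combine per-chunk no_sync() context managers into one callable,
    mirroring the multi_no_sync() closure in the hunk above."""
    def multi_no_sync():
        stack = contextlib.ExitStack()
        for chunk_no_sync in per_chunk_no_sync_funcs:
            stack.enter_context(chunk_no_sync())
        return stack
    return multi_no_sync


# Usage (hypothetical chunk objects exposing .no_sync, e.g. DDP-style wrappers):
#     no_sync_func = make_multi_no_sync([chunk0.no_sync, chunk1.no_sync])
#     with no_sync_func():
#         ...  # gradient all-reduce stays disabled for every chunk in this block
# Entering the returned ExitStack is a no-op; exiting it closes all the
# per-chunk contexts that were entered when the callable was invoked.
```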
@@ -565,7 +588,7 @@ def forward_backward_pipelining_with_cutinhalf(
):
"""Helper method to run combined forward and backward step"""
# forward prepare
fwd_microbatch_id = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
f_context = contextlib.nullcontext()
set_dualpipe_chunk(fwd_model_chunk_id)
@@ -653,11 +676,9 @@ def forward_backward_pipelining_with_cutinhalf(
master_chunk_id = 0
slave_chunk_id = 1
master_cur_microbatch = 0
slave_cur_microbatch = num_microbatches
master_microbatch_max = num_microbatches
slave_microbatch_max = num_microbatches * 2
cur_fwd_chunk_microbatch = [0, num_microbatches]
cur_bwd_chunk_microbatch = [0, num_microbatches]
num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
checkpoint_activations_microbatch = None
fwd_wait_handles_warmup = None
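To make the counter refactor concrete, a small illustrative sketch (values are examples only) of how the per-chunk lists replace the former master_cur_microbatch / slave_cur_microbatch scalars, with chunk 0 as the master chunk and chunk 1 as the slave chunk as in this file:

```python
num_microbatches = 4  # example value

# Old bookkeeping: four separate scalars.
master_cur_microbatch, master_microbatch_max = 0, num_microbatches
slave_cur_microbatch, slave_microbatch_max = num_microbatches, num_microbatches * 2

# New bookkeeping: index by chunk id (0 = master, 1 = slave).
cur_fwd_chunk_microbatch = [0, num_microbatches]
cur_bwd_chunk_microbatch = [0, num_microbatches]
num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]


def fwd_done(chunk_id):
    """A chunk has no forward work left once its counter reaches its maximum;
    this replaces the separate master/slave comparisons used before."""
    return cur_fwd_chunk_microbatch[chunk_id] == num_chunk_max_microbatch[chunk_id]
```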
@@ -688,33 +709,31 @@ def forward_backward_pipelining_with_cutinhalf(
return output_tensor
def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, is_last_microbatch=False):
# # launch grad synchronization (default)
# if config.grad_sync_func is None and is_last_microbatch:
# enable_grad_sync()
def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, bwd_model_chunk_id=None, bwd_cur_microbatch=None):
nonlocal master_chunk_id
nonlocal slave_chunk_id
nonlocal num_chunk_max_microbatch
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if (
bwd_cur_microbatch is not None
and bwd_cur_microbatch == num_chunk_max_microbatch[bwd_model_chunk_id] - 1
):
if (
config.grad_sync_func is None
or (bwd_model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
or (bwd_model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
):
enable_grad_sync()
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# # launch grad synchronization (custom grad sync)
# # Note: Asynchronous communication tends to slow down compute.
# # To reduce idling from mismatched microbatch times, we launch
# # asynchronous communication at the same time across the
# # pipeline-parallel group.
# if config.grad_sync_func is not None:
# grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
# if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
# grad_sync_virtual_microbatch_id
# ):
# grad_sync_chunk_id = get_model_chunk_id(
# grad_sync_virtual_microbatch_id, forward=False
# )
# enable_grad_sync()
# config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
# synchronized_model_chunks.add(grad_sync_chunk_id)
# disable_grad_sync()
return input_tensor_grad
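For clarity, a boiled-down sketch (illustrative, not the patch itself) of the guard under which the new backward_step_helper enables async grad reduction: only on a chunk's last backward microbatch and, when a custom grad_sync_func is configured, only on the last pipeline stage for the slave chunk or the first pipeline stage for the master chunk:

```python
def should_enable_grad_sync(bwd_model_chunk_id, bwd_cur_microbatch,
                            num_chunk_max_microbatch, grad_sync_func,
                            is_first_stage, is_last_stage,
                            master_chunk_id=0, slave_chunk_id=1):
    """Mirror of the enable_grad_sync() guard in backward_step_helper above."""
    if bwd_cur_microbatch is None:
        return False
    if bwd_cur_microbatch != num_chunk_max_microbatch[bwd_model_chunk_id] - 1:
        return False
    return (grad_sync_func is None
            or (bwd_model_chunk_id == slave_chunk_id and is_last_stage)
            or (bwd_model_chunk_id == master_chunk_id and is_first_stage))
```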
# Run warmup forward passes
@@ -724,11 +743,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_warmup = forward_step_helper(
input_tensor,
master_chunk_id,
master_cur_microbatch,
cur_fwd_chunk_microbatch[master_chunk_id],
is_first_microbatch=is_first_microbatch
)
master_cur_microbatch += 1
cur_fwd_chunk_microbatch[master_chunk_id] += 1
if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward(
@@ -758,10 +776,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor = forward_step_helper(
input_tensor,
master_chunk_id,
master_cur_microbatch,
cur_fwd_chunk_microbatch[master_chunk_id],
is_first_microbatch=is_first_microbatch
)
master_cur_microbatch += 1
cur_fwd_chunk_microbatch[master_chunk_id] += 1
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
for req, req_handle in fwd_wait_handles_send.items():
@@ -810,10 +828,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_slave_chunk = forward_step_helper(
input_tensor_slave,
slave_chunk_id,
slave_cur_microbatch,
cur_fwd_chunk_microbatch[slave_chunk_id],
is_first_microbatch=is_first_microbatch
)
slave_cur_microbatch += 1
cur_fwd_chunk_microbatch[slave_chunk_id] += 1
if not forward_only:
if i == schedule['interleaved_forward'][rank] - 1:
@@ -849,10 +867,8 @@ def forward_backward_pipelining_with_cutinhalf(
if not forward_only:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
cur_bwd_chunk_microbatch[slave_chunk_id] += 1
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
@@ -890,10 +906,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor = forward_step_helper(
input_tensor_slave,
slave_chunk_id,
slave_cur_microbatch,
cur_fwd_chunk_microbatch[slave_chunk_id],
is_first_microbatch=False
)
slave_cur_microbatch += 1
cur_fwd_chunk_microbatch[slave_chunk_id] += 1
if not forward_only:
output_tensor_grad_bwd, _ = recv_backward(
@@ -913,9 +929,7 @@ def forward_backward_pipelining_with_cutinhalf(
num_overlap_steps += schedule['interleaved_backward'][rank]
for step_id in range(num_overlap_steps):
only_bwd = False
if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
only_bwd = True
if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
if cur_fwd_chunk_microbatch[fwd_model_chunk_id] == num_chunk_max_microbatch[fwd_model_chunk_id]:
only_bwd = True
if not only_bwd:
@@ -928,18 +942,16 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles_recv = None
def pp_post_forward(output_tensor):
nonlocal master_cur_microbatch
nonlocal slave_cur_microbatch
nonlocal cur_fwd_chunk_microbatch
nonlocal num_chunk_max_microbatch
nonlocal fwd_wait_handles
nonlocal fwd_wait_handles_slave_chunk
nonlocal firstFB_no_overlp_handle
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch == master_microbatch_max)
fwd_send_only = (cur_fwd_chunk_microbatch[master_chunk_id] == num_chunk_max_microbatch[master_chunk_id])
# Synchronize the last slave-chunk forward send from the previous phase
if fwd_wait_handles_slave_chunk is not None:
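A small hypothetical helper (not in the patch) restating the fwd_send_only decision from pp_post_forward under the new per-chunk counters: after a slave-chunk forward, the send no longer needs a paired receive once the master chunk has run out of forward microbatches.

```python
def fwd_send_only_after_forward(fwd_model_chunk_id, master_chunk_id,
                                cur_fwd_chunk_microbatch, num_chunk_max_microbatch):
    """Mirror of the fwd_send_only assignment in pp_post_forward above."""
    if fwd_model_chunk_id == master_chunk_id:
        return False  # fwd_send_only is always False after a master-chunk forward
    return (cur_fwd_chunk_microbatch[master_chunk_id]
            == num_chunk_max_microbatch[master_chunk_id])
```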
@@ -1016,13 +1028,13 @@ def forward_backward_pipelining_with_cutinhalf(
# forward
pp_pre_forward()
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
output_tensor = forward_step_helper(
input_tensor,
fwd_model_chunk_id,
fwd_microbatch,
cur_fwd_chunk_microbatch[fwd_model_chunk_id],
is_first_microbatch=False
)
cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
input_tensor = pp_post_forward(output_tensor)
# backward
@@ -1031,13 +1043,14 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
else:
input_tensor_grad = None
output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)
# only run backward
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]:
input_tensor, fwd_wait_handles_recv = recv_forward(
tensor_shape, config, slave_chunk_id, async_op=True)
if not forward_only:
@@ -1049,75 +1062,71 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
input_tensor_grad = backward_step_helper(
input_tensor_bwd,
output_tensor_bwd,
output_tensor_grad_bwd,
bwd_model_chunk_id=bwd_model_chunk_id,
bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
)
cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
if step_id == num_overlap_steps - 1:
bwd_wait_handles = send_backward(
input_tensor_grad,
tensor_shape,
config,
bwd_model_chunk_id,
)
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
input_tensor_grad,
tensor_shape,
config,
fwd_model_chunk_id
)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
if not forward_only:
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []
while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
if len(input_tensors[bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[1 - bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
bwd_wait_handles_recv = None
for i in range(pp_size):
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req, req_handle in bwd_wait_handles_recv.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles_recv = None
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
# Launch any remaining grad reductions.
if config.grad_sync_func is not None:
enable_grad_sync()
config.grad_sync_func(model[slave_chunk_id].parameters())
disable_grad_sync()
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
# Run cooldown phases
if not forward_only:
for i in range(schedule['cooldown'][rank][0]):
output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, master_chunk_id)
input_tensor_bwd = input_tensors[master_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[master_chunk_id].pop(0)
input_tensor_grad = backward_step_helper(
input_tensor_bwd,
output_tensor_bwd,
output_tensor_grad_bwd,
bwd_model_chunk_id=master_chunk_id,
bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
)
cur_bwd_chunk_microbatch[master_chunk_id] += 1
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
elif i >= schedule['cooldown'][rank][0] - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
tensor_shape, config, bwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
_ = send_backward(
input_tensor_grad,
tensor_shape,
config,
master_chunk_id,
)
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
# Launch any remaining grad reductions.
if config.grad_sync_func is not None:
enable_grad_sync()
config.grad_sync_func(model[master_chunk_id].parameters())
if config.finalize_model_grads_func is not None and not forward_only:
@@ -1132,4 +1141,8 @@ def forward_backward_pipelining_with_cutinhalf(
model, total_num_tokens if config.calculate_per_token_loss else None
)
# Restore config.grad_sync_func and config.param_sync_func.
if forward_only:
config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func
return forward_data_store