Commit 2385a133 authored by dongcl

modify dualpipev, add forward_step_helper and backward_step_helper

parent cb1230db
@@ -354,9 +354,9 @@ def generate_dualpipev_schedule(pp_size, num_microbatches):
         num_1b1overlap_stages[i] = (pp_size // 2 - i - 1) * 2
-        num_interleaved_backward_stages[i] = i + 1
-        num_cooldown_stages[i] = [i + 1, pp_size - 2 * i - 2, i + 1]
+        num_interleaved_backward_stages[i] = (i + 1) * 2
+        num_cooldown_stages[i] = [pp_size // 2 - i - 1, pp_size - 2 * i - 2, i + 1]
 
     schedule_all_stages = {
         'warmup': num_warmup_stages,
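Not part of the commit: a quick sanity check of the changed per-rank counts, with pp_size = 8 and rank i = 0 as assumed illustrative values (num_warmup_stages is outside this hunk and is not shown).

```python
# Illustrative evaluation of the changed per-rank counts (assumed pp_size = 8, rank i = 0).
pp_size, i = 8, 0

num_1b1overlap = (pp_size // 2 - i - 1) * 2                        # unchanged -> 6
num_interleaved_backward = (i + 1) * 2                             # new -> 2 (old: i + 1 -> 1)
num_cooldown = [pp_size // 2 - i - 1, pp_size - 2 * i - 2, i + 1]  # new -> [3, 6, 1] (old first entry: i + 1 -> 1)
print(num_1b1overlap, num_interleaved_backward, num_cooldown)
```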
@@ -537,10 +537,33 @@ def forward_backward_pipelining_with_cutinhalf(
     # Disable async grad reductions
     no_sync_func = config.no_sync_func
+    if isinstance(no_sync_func, list):
+
+        def multi_no_sync():
+            stack = contextlib.ExitStack()
+            for model_chunk_no_sync_func in config.no_sync_func:
+                stack.enter_context(model_chunk_no_sync_func())
+            return stack
+
+        no_sync_func = multi_no_sync
+
     if no_sync_func is None:
         no_sync_func = contextlib.nullcontext
     no_sync_context = None
 
+    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
+        config.grad_sync_func = [config.grad_sync_func for _ in model]
+
+    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
+        config.param_sync_func = [config.param_sync_func for _ in model]
+
+    # Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
+    # They will be re-enabled at the end of this function.
+    grad_sync_func, param_sync_func = None, None
+    if forward_only:
+        grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
+        config.grad_sync_func, config.param_sync_func = None, None
+
     def disable_grad_sync():
         """Disable asynchronous grad reductions"""
         nonlocal no_sync_context
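Not part of the commit: a minimal sketch of the ExitStack pattern used by multi_no_sync above, which enters one no-sync context per model chunk and returns a single combined context. The dummy no-sync factories below are assumptions for illustration only.

```python
# Sketch only: combining several per-chunk no-sync context managers with ExitStack.
import contextlib

@contextlib.contextmanager
def make_no_sync(chunk_id):  # stand-in for a model chunk's no_sync() factory
    print(f"chunk {chunk_id}: grad sync disabled")
    yield
    print(f"chunk {chunk_id}: grad sync re-enabled")

no_sync_funcs = [lambda i=i: make_no_sync(i) for i in range(2)]

def multi_no_sync():
    stack = contextlib.ExitStack()
    for chunk_no_sync in no_sync_funcs:
        stack.enter_context(chunk_no_sync())
    return stack

with multi_no_sync():
    pass  # both chunks' no-sync contexts are active here
```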
@@ -565,7 +588,7 @@ def forward_backward_pipelining_with_cutinhalf(
     ):
         """Helper method to run combined forward and backward step"""
         # forward prepare
-        fwd_microbatch_id = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
+        fwd_microbatch_id = cur_fwd_chunk_microbatch[fwd_model_chunk_id]
         f_context = contextlib.nullcontext()
         set_dualpipe_chunk(fwd_model_chunk_id)
@@ -653,11 +676,9 @@
     master_chunk_id = 0
     slave_chunk_id = 1
-    master_cur_microbatch = 0
-    slave_cur_microbatch = num_microbatches
-    master_microbatch_max = num_microbatches
-    slave_microbatch_max = num_microbatches * 2
+    cur_fwd_chunk_microbatch = [0, num_microbatches]
+    cur_bwd_chunk_microbatch = [0, num_microbatches]
+    num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
 
     checkpoint_activations_microbatch = None
     fwd_wait_handles_warmup = None
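Not part of the commit: a minimal sketch of how the four scalar counters above collapse into the new per-chunk lists, assuming chunk id 0 is the master chunk and 1 is the slave chunk (num_microbatches = 16 is an arbitrary illustrative value).

```python
# Sketch only: per-chunk bookkeeping that replaces the scalar counters.
num_microbatches = 16  # assumed value for illustration
master_chunk_id, slave_chunk_id = 0, 1

# index by chunk id: [master, slave]
cur_fwd_chunk_microbatch = [0, num_microbatches]   # was master_cur_microbatch / slave_cur_microbatch
cur_bwd_chunk_microbatch = [0, num_microbatches]   # backward progress now tracked per chunk as well
num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]  # was *_microbatch_max

# advancing a counter and testing for the last microbatch of a chunk:
cur_fwd_chunk_microbatch[master_chunk_id] += 1
fwd_done = cur_fwd_chunk_microbatch[master_chunk_id] == num_chunk_max_microbatch[master_chunk_id]
print(fwd_done)
```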
@@ -688,33 +709,31 @@
         return output_tensor
 
-    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, is_last_microbatch=False):
-        # # launch grad synchronization (default)
-        # if config.grad_sync_func is None and is_last_microbatch:
-        #     enable_grad_sync()
+    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, bwd_model_chunk_id=None, bwd_cur_microbatch=None):
+        nonlocal master_chunk_id
+        nonlocal slave_chunk_id
+        nonlocal num_chunk_max_microbatch
+
+        # Enable async grad reduction in the last backward pass
+        # Note: If grad sync function is provided, only enable
+        # async grad reduction in first pipeline stage. Other
+        # pipeline stages do grad reduction during pipeline
+        # bubble.
+        if (
+            bwd_cur_microbatch is not None
+            and bwd_cur_microbatch == num_chunk_max_microbatch[bwd_model_chunk_id] - 1
+        ):
+            if (
+                config.grad_sync_func is None
+                or (bwd_model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
+                or (bwd_model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
+            ):
+                enable_grad_sync()
 
         input_tensor_grad = backward_step(
             input_tensor, output_tensor, output_tensor_grad, model_type, config
         )
 
-        # # launch grad synchronization (custom grad sync)
-        # # Note: Asynchronous communication tends to slow down compute.
-        # # To reduce idling from mismatched microbatch times, we launch
-        # # asynchronous communication at the same time across the
-        # # pipeline-parallel group.
-        # if config.grad_sync_func is not None:
-        #     grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
-        #     if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
-        #         grad_sync_virtual_microbatch_id
-        #     ):
-        #         grad_sync_chunk_id = get_model_chunk_id(
-        #             grad_sync_virtual_microbatch_id, forward=False
-        #         )
-        #         enable_grad_sync()
-        #         config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
-        #         synchronized_model_chunks.add(grad_sync_chunk_id)
-        #         disable_grad_sync()
-
         return input_tensor_grad
 
     # Run warmup forward passes
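The new backward_step_helper re-enables async grad reduction only on a chunk's final backward microbatch and, when a custom config.grad_sync_func is set, only on the pipeline stage that should launch that chunk's reduction. Below is a standalone sketch of the same predicate (not part of the commit; is_first_stage / is_last_stage stand in for the parallel_state calls).

```python
# Sketch only: the grad-sync gating condition used by the new backward_step_helper.
def should_enable_grad_sync(bwd_cur_microbatch, bwd_model_chunk_id,
                            num_chunk_max_microbatch, grad_sync_func,
                            is_first_stage, is_last_stage,
                            master_chunk_id=0, slave_chunk_id=1):
    # Only the last backward microbatch of this chunk can trigger grad sync.
    if bwd_cur_microbatch is None:
        return False
    if bwd_cur_microbatch != num_chunk_max_microbatch[bwd_model_chunk_id] - 1:
        return False
    # With a custom grad_sync_func, only the stage owning the chunk's boundary enables
    # sync here; other stages reduce grads during the pipeline bubble.
    return (
        grad_sync_func is None
        or (bwd_model_chunk_id == slave_chunk_id and is_last_stage)
        or (bwd_model_chunk_id == master_chunk_id and is_first_stage)
    )
```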
@@ -724,11 +743,10 @@
         output_tensor_warmup = forward_step_helper(
             input_tensor,
             master_chunk_id,
-            master_cur_microbatch,
+            cur_fwd_chunk_microbatch[master_chunk_id],
             is_first_microbatch=is_first_microbatch
         )
-
-        master_cur_microbatch += 1
+        cur_fwd_chunk_microbatch[master_chunk_id] += 1
 
         if i != schedule['warmup'][rank] - 1:
             input_tensor, _ = send_forward_recv_forward(
@@ -758,10 +776,10 @@
         output_tensor = forward_step_helper(
             input_tensor,
             master_chunk_id,
-            master_cur_microbatch,
+            cur_fwd_chunk_microbatch[master_chunk_id],
             is_first_microbatch=is_first_microbatch
         )
-        master_cur_microbatch += 1
+        cur_fwd_chunk_microbatch[master_chunk_id] += 1
 
         if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
             for req, req_handle in fwd_wait_handles_send.items():
...@@ -810,10 +828,10 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -810,10 +828,10 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_slave_chunk = forward_step_helper( output_tensor_slave_chunk = forward_step_helper(
input_tensor_slave, input_tensor_slave,
slave_chunk_id, slave_chunk_id,
slave_cur_microbatch, cur_fwd_chunk_microbatch[slave_chunk_id],
is_first_microbatch=is_first_microbatch is_first_microbatch=is_first_microbatch
) )
slave_cur_microbatch += 1 cur_fwd_chunk_microbatch[slave_chunk_id] += 1
if not forward_only: if not forward_only:
if i == schedule['interleaved_forward'][rank] - 1: if i == schedule['interleaved_forward'][rank] - 1:
@@ -849,10 +867,8 @@
             if not forward_only:
                 input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
                 output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-
-                input_tensor_grad = backward_step(
-                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-                )
+                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+                cur_bwd_chunk_microbatch[slave_chunk_id] += 1
 
             if fwd_wait_handles_slave_chunk is not None:
                 for req in fwd_wait_handles_slave_chunk:
@@ -890,10 +906,10 @@
             output_tensor = forward_step_helper(
                 input_tensor_slave,
                 slave_chunk_id,
-                slave_cur_microbatch,
+                cur_fwd_chunk_microbatch[slave_chunk_id],
                 is_first_microbatch=False
             )
-            slave_cur_microbatch += 1
+            cur_fwd_chunk_microbatch[slave_chunk_id] += 1
 
             if not forward_only:
                 output_tensor_grad_bwd, _ = recv_backward(
@@ -913,9 +929,7 @@
     num_overlap_steps += schedule['interleaved_backward'][rank]
     for step_id in range(num_overlap_steps):
         only_bwd = False
-        if fwd_model_chunk_id == master_chunk_id and master_cur_microbatch == master_microbatch_max:
-            only_bwd = True
-        if fwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch == slave_microbatch_max:
+        if cur_fwd_chunk_microbatch[fwd_model_chunk_id] == num_chunk_max_microbatch[fwd_model_chunk_id]:
             only_bwd = True
 
         if not only_bwd:
@@ -928,18 +942,16 @@
             fwd_wait_handles_recv = None
 
             def pp_post_forward(output_tensor):
-                nonlocal master_cur_microbatch
-                nonlocal slave_cur_microbatch
+                nonlocal cur_fwd_chunk_microbatch
+                nonlocal num_chunk_max_microbatch
                 nonlocal fwd_wait_handles
                 nonlocal fwd_wait_handles_slave_chunk
                 nonlocal firstFB_no_overlp_handle
 
                 if fwd_model_chunk_id == master_chunk_id:
-                    master_cur_microbatch += 1
                     fwd_send_only = False
                 else:
-                    slave_cur_microbatch += 1
-                    fwd_send_only = (master_cur_microbatch == master_microbatch_max)
+                    fwd_send_only = (cur_fwd_chunk_microbatch[master_chunk_id] == num_chunk_max_microbatch[master_chunk_id])
 
                 # Synchronize the last slave-chunk forward send of the previous phase
                 if fwd_wait_handles_slave_chunk is not None:
@@ -1016,13 +1028,13 @@
             # forward
             pp_pre_forward()
-            fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
             output_tensor = forward_step_helper(
                 input_tensor,
                 fwd_model_chunk_id,
-                fwd_microbatch,
+                cur_fwd_chunk_microbatch[fwd_model_chunk_id],
                 is_first_microbatch=False
             )
+            cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
             input_tensor = pp_post_forward(output_tensor)
 
             # backward
@@ -1031,13 +1043,14 @@
                 input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                 output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
                 input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+                cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
             else:
                 input_tensor_grad = None
             output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)
 
         # only run backward
         else:
-            if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
+            if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]:
                 input_tensor, fwd_wait_handles_recv = recv_forward(
                     tensor_shape, config, slave_chunk_id, async_op=True)
             if not forward_only:
@@ -1049,75 +1062,71 @@
                 input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
                 output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+                input_tensor_grad = backward_step_helper(
+                    input_tensor_bwd,
+                    output_tensor_bwd,
+                    output_tensor_grad_bwd,
+                    bwd_model_chunk_id=bwd_model_chunk_id,
+                    bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
+                )
+                cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
 
                 if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                     output_tensor_grad_bwd = input_tensor_grad
                 else:
-                    if step_id == num_overlap_steps - 1:
-                        bwd_wait_handles = send_backward(
-                            input_tensor_grad,
-                            tensor_shape,
-                            config,
-                            bwd_model_chunk_id,
-                        )
-                    else:
-                        # send_backward_recv_slave_backward
-                        output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
-                            tensor_shape, config, fwd_model_chunk_id)
+                    # send_backward_recv_slave_backward
+                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
+                        input_tensor_grad,
+                        tensor_shape,
+                        config,
+                        fwd_model_chunk_id
+                    )
 
         # swap fwd & bwd chunks
         fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
 
-    if not forward_only:
-        # Run cooldown phases
-        merged_input_tensors = []
-        merged_output_tensors = []
-        while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
-            if len(input_tensors[bwd_model_chunk_id]) > 0:
-                merged_input_tensors.append(
-                    input_tensors[bwd_model_chunk_id].pop(0))
-                merged_output_tensors.append(
-                    (output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
-            if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
-                merged_input_tensors.append(
-                    input_tensors[1 - bwd_model_chunk_id].pop(0))
-                merged_output_tensors.append(
-                    (output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
-
-        bwd_wait_handles_recv = None
-        for i in range(pp_size):
-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles = None
-            if bwd_wait_handles_recv is not None:
-                for req, req_handle in bwd_wait_handles_recv.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles_recv = None
-
-            input_tensor_bwd = merged_input_tensors.pop(0)[1]
-            output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
-
-            input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
-
-            if i == pp_size - 1:
-                bwd_wait_handles = send_backward(input_tensor_grad,
-                    tensor_shape, config, bwd_model_chunk_id, async_op=True)
-            elif i >= schedule['cooldown'][rank][0] - 1:
-                bwd_wait_handles = send_backward(input_tensor_grad,
-                    tensor_shape, config, bwd_model_chunk_id, async_op=True)
-                output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
-                    tensor_shape, config, bwd_model_chunk_id, async_op=True)
-            else:
-                if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
-                else:
-                    # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
-                        tensor_shape, config, 1 - bwd_model_chunk_id)
-
-        if bwd_wait_handles is not None:
-            for req, req_handle in bwd_wait_handles.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            bwd_wait_handles = None
+    # Launch any remaining grad reductions.
+    if config.grad_sync_func is not None:
+        enable_grad_sync()
+        config.grad_sync_func[slave_chunk_id](model[slave_chunk_id].parameters())
+        disable_grad_sync()
+
+    # Run cooldown phases
+    if not forward_only:
+        for i in range(schedule['cooldown'][rank][0]):
+            output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, master_chunk_id)
+            input_tensor_bwd = input_tensors[master_chunk_id].pop(0)[1]
+            output_tensor_bwd = output_tensors[master_chunk_id].pop(0)
+            input_tensor_grad = backward_step_helper(
+                input_tensor_bwd,
+                output_tensor_bwd,
+                output_tensor_grad_bwd,
+                bwd_model_chunk_id=master_chunk_id,
+                bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
+            )
+            cur_bwd_chunk_microbatch[master_chunk_id] += 1
+
+            _ = send_backward(
+                input_tensor_grad,
+                tensor_shape,
+                config,
+                master_chunk_id,
+            )
+
+    # Launch any remaining grad reductions.
+    if config.grad_sync_func is not None:
+        enable_grad_sync()
+        config.grad_sync_func[master_chunk_id](model[master_chunk_id].parameters())
 
     if config.finalize_model_grads_func is not None and not forward_only:
@@ -1132,4 +1141,8 @@
             model, total_num_tokens if config.calculate_per_token_loss else None
         )
 
+    # Restore config.grad_sync_func and config.param_sync_func.
+    if forward_only:
+        config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func
+
     return forward_data_store
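Not part of the commit: a minimal sketch of the save / disable / restore pattern the diff applies to config.grad_sync_func and config.param_sync_func around a forward-only run (SimpleNamespace stands in for the real config object, and the lambda hooks are placeholders).

```python
# Sketch only: disable grad/param sync hooks for forward-only runs, restore them afterwards.
from types import SimpleNamespace

config = SimpleNamespace(grad_sync_func=lambda params: None,
                         param_sync_func=lambda params: None)
forward_only = True

# save and disable (mirrors the lines added near the top of the function)
grad_sync_func, param_sync_func = None, None
if forward_only:
    grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
    config.grad_sync_func, config.param_sync_func = None, None

# ... run the forward-only pipeline schedule here ...

# restore (mirrors the lines added just before `return forward_data_store`)
if forward_only:
    config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func
```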