Commit cb1230db authored by dongcl

rewrite dualpipev_schedules; deduplicate the code

parent a58a2da6
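The change folds the repeated forward/backward step calls and their tensor bookkeeping into shared helpers (`forward_step_helper`, `backward_step_helper`, `combined_forward_backward_helper`) that the schedule phases reuse instead of repeating the same inline blocks. As a rough illustration of that deduplication pattern only, here is a minimal self-contained sketch with toy stand-ins; `toy_forward`, `make_forward_step_helper`, and `run_schedule` are hypothetical names, not part of this repository or of the real schedule code.

```python
# Illustrative sketch of the refactoring pattern only -- toy stand-ins, not the real schedule.
from typing import Any, Dict, List, Tuple


def toy_forward(chunk_id: int, microbatch: int) -> float:
    """Stand-in for a forward pass on one model chunk."""
    return float(chunk_id * 100 + microbatch)


def make_forward_step_helper(input_store: Dict[int, List[Tuple[int, Any]]],
                             output_store: Dict[int, List[float]]):
    """Build a helper that runs one forward step and records its tensors,
    so the bookkeeping is written once instead of in every schedule phase."""
    def forward_step_helper(input_tensor: Any, chunk_id: int, microbatch: int) -> float:
        output = toy_forward(chunk_id, microbatch)
        # Record (microbatch, input) and output so a later backward step can pop them.
        input_store[chunk_id].append((microbatch, input_tensor))
        output_store[chunk_id].append(output)
        return output
    return forward_step_helper


def run_schedule(num_microbatches: int) -> Dict[int, List[float]]:
    input_store: Dict[int, List[Tuple[int, Any]]] = {0: [], 1: []}
    output_store: Dict[int, List[float]] = {0: [], 1: []}
    step = make_forward_step_helper(input_store, output_store)
    # Every phase calls the same helper instead of duplicating the
    # forward call plus the two store-append lines in each loop.
    for mb in range(num_microbatches):
        step(None, 0, mb)  # "master" chunk
    for mb in range(num_microbatches):
        step(None, 1, mb)  # "slave" chunk
    return output_store


if __name__ == "__main__":
    print(run_schedule(2))  # {0: [0.0, 1.0], 1: [100.0, 101.0]}
```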
@@ -557,6 +557,82 @@ def forward_backward_pipelining_with_cutinhalf(
disable_grad_sync()
def combined_forward_backward_helper(
fwd_model_chunk_id,
bwd_model_chunk_id,
fwd_input_tensor=None,
bwd_output_tensor_grad=None
):
"""Helper method to run combined forward and backward step"""
# forward prepare
fwd_microbatch_id = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
f_context = contextlib.nullcontext()
set_dualpipe_chunk(fwd_model_chunk_id)
# backward prepare
b_context = contextlib.nullcontext()
bwd_input_tensor = input_tensors[bwd_model_chunk_id].pop(0)[1]
bwd_output_tensor = output_tensors[bwd_model_chunk_id].pop(0)
output_tensor, num_tokens, input_tensor_grad = forward_backward_step(
forward_step_func,
data_iterator[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
model[fwd_model_chunk_id] if fwd_model_chunk_id is not None else None,
num_microbatches,
fwd_input_tensor,
forward_data_store,
model[bwd_model_chunk_id] if bwd_model_chunk_id is not None else None,
bwd_input_tensor,
bwd_output_tensor,
bwd_output_tensor_grad,
config,
f_context=f_context,
b_context=b_context,
collect_non_loss_data=collect_non_loss_data,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=fwd_microbatch_id,
)
# forward post process
if fwd_model_chunk_id is not None:
with f_context:
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[fwd_model_chunk_id].append((fwd_microbatch_id, fwd_input_tensor))
output_tensors[fwd_model_chunk_id].append(output_tensor)
# backward post process
if bwd_model_chunk_id is not None:
with b_context:
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.grad_sync_func is not None:
grad_sync_virtual_microbatch_id = (
b_virtual_microbatch_id - pipeline_parallel_rank
)
if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_virtual_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(
grad_sync_virtual_microbatch_id, forward=False
)
enable_grad_sync()
config.grad_sync_func[grad_sync_chunk_id](
model[grad_sync_chunk_id].parameters()
)
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
if bwd_input_tensor is not None:
assert input_tensor_grad is not None
return output_tensor, input_tensor_grad
# Compute number of steps for each stage
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
rank = parallel_state.get_pipeline_model_parallel_rank()
@@ -583,35 +659,74 @@ def forward_backward_pipelining_with_cutinhalf(
master_microbatch_max = num_microbatches
slave_microbatch_max = num_microbatches * 2
set_dualpipe_chunk(master_chunk_id)
checkpoint_activations_microbatch = None
input_tensor = recv_forward(tensor_shape, config, master_chunk_id, step=0)[0]
fwd_wait_handles_warmup = None
# Run warmup forward passes
for i in range(schedule['warmup'][rank]):
output_tensor_warmup, num_tokens = forward_step_no_model_graph(
def forward_step_helper(input_tensor, model_chunk_id, cur_microbatch, is_first_microbatch=False):
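# Runs one forward microbatch on `model_chunk_id` via `forward_step_no_model_graph`,
# accumulates `total_num_tokens`, and records the (microbatch, input)/output pair
# for the matching backward step.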
set_dualpipe_chunk(model_chunk_id)
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
model_chunk_id,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=master_cur_microbatch
is_first_microbatch=is_first_microbatch,
current_microbatch=cur_microbatch
)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
input_tensors[model_chunk_id].append(
(cur_microbatch, input_tensor))
output_tensors[model_chunk_id].append(output_tensor)
return output_tensor
def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, is_last_microbatch=False):
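# Runs one backward microbatch via `backward_step` and returns the gradient
# w.r.t. its input tensor; the commented-out grad-sync hooks below are left disabled.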
# # launch grad synchronization (default)
# if config.grad_sync_func is None and is_last_microbatch:
# enable_grad_sync()
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# # launch grad synchronization (custom grad sync)
# # Note: Asynchronous communication tends to slow down compute.
# # To reduce idling from mismatched microbatch times, we launch
# # asynchronous communication at the same time across the
# # pipeline-parallel group.
# if config.grad_sync_func is not None:
# grad_sync_virtual_microbatch_id = virtual_microbatch_id - pipeline_parallel_rank
# if grad_sync_virtual_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
# grad_sync_virtual_microbatch_id
# ):
# grad_sync_chunk_id = get_model_chunk_id(
# grad_sync_virtual_microbatch_id, forward=False
# )
# enable_grad_sync()
# config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
# synchronized_model_chunks.add(grad_sync_chunk_id)
# disable_grad_sync()
return input_tensor_grad
# Run warmup forward passes
input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
for i in range(schedule['warmup'][rank]):
is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
output_tensor_warmup = forward_step_helper(
input_tensor,
master_chunk_id,
master_cur_microbatch,
is_first_microbatch=is_first_microbatch
)
master_cur_microbatch += 1
@@ -639,29 +754,13 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_wait_handles = None
is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
set_dualpipe_chunk(master_chunk_id)
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
master_chunk_id,
data_iterator[master_chunk_id],
model[master_chunk_id],
num_microbatches,
is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch)
output_tensor = forward_step_helper(
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
current_microbatch=master_cur_microbatch
master_chunk_id,
master_cur_microbatch,
is_first_microbatch=is_first_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
master_cur_microbatch += 1
if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
@@ -676,19 +775,16 @@ def forward_backward_pipelining_with_cutinhalf(
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if not forward_only:
input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True
input_tensor_slave = output_tensor.detach()
input_tensor_slave.requires_grad = True
else:
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
input_tensor_slave = output_tensor
else:
input_tensor_slave_chunk, _ = recv_forward(
input_tensor_slave, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
if fwd_wait_handles_warmup is not None:
for req, req_handle in fwd_wait_handles_warmup.items():
@@ -710,28 +806,13 @@ def forward_backward_pipelining_with_cutinhalf(
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
output_tensor_slave_chunk = forward_step_helper(
input_tensor_slave,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch,
slave_cur_microbatch,
is_first_microbatch=is_first_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if not forward_only:
@@ -765,8 +846,6 @@ def forward_backward_pipelining_with_cutinhalf(
# Run 1b1w1f stages for slave chunk
bwd_wait_handles = None
for _ in range(schedule['1b1w1f'][rank]):
# WeightGradStore.start_decouple()
if not forward_only:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
@@ -775,8 +854,6 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
# WeightGradStore.end_decouple()
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
@@ -800,12 +877,9 @@ def forward_backward_pipelining_with_cutinhalf(
tensor_shape, config, slave_chunk_id)
# Keep this receive synchronous; an asynchronous receive increases memory usage.
input_tensor_slave_chunk, recv_forward_handle = recv_forward(
input_tensor_slave, recv_forward_handle = recv_forward(
tensor_shape, config, slave_chunk_id)
# 1w: Weight Grad Compute
# WeightGradStore.pop()
if recv_forward_handle is not None:
for req, handle in recv_forward_handle.items():
if handle is not None:
@@ -813,27 +887,12 @@ def forward_backward_pipelining_with_cutinhalf(
recv_forward_handle = None
# 1F: Forward pass
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
output_tensor = forward_step_helper(
input_tensor_slave,
slave_chunk_id,
data_iterator[slave_chunk_id],
model[slave_chunk_id],
num_microbatches,
input_tensor_slave_chunk,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=slave_cur_microbatch
slave_cur_microbatch,
is_first_microbatch=False
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if not forward_only:
@@ -844,6 +903,8 @@ def forward_backward_pipelining_with_cutinhalf(
tensor_shape, config, slave_chunk_id, async_op=True)
# Run overlapping forward & backward stages
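# Each iteration overlaps a forward step on `fwd_model_chunk_id` with a backward
# step on `bwd_model_chunk_id`; once the forward microbatches are exhausted
# (`only_bwd`), only the backward step runs.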
fwd_wait_handles = None
bwd_wait_handles = None
fwd_wait_handles_recv = None
fwd_model_chunk_id = master_chunk_id
bwd_model_chunk_id = slave_chunk_id
@@ -858,105 +919,121 @@ def forward_backward_pipelining_with_cutinhalf(
only_bwd = True
if not only_bwd:
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
set_dualpipe_chunk(fwd_model_chunk_id)
def pp_pre_forward():
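# Waits for any outstanding async receive of the next forward input before
# the forward step is launched.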
nonlocal fwd_wait_handles_recv
if fwd_wait_handles_recv is not None:
for req, req_handle in fwd_wait_handles_recv.items():
req_handle.wait()
fwd_wait_handles_recv = None
output_tensor, num_tokens = forward_step_no_model_graph(
forward_step_func,
fwd_model_chunk_id,
data_iterator[fwd_model_chunk_id],
model[fwd_model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch == master_microbatch_max)
# Synchronize the last slave-chunk forward send from the previous phase
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
if fwd_wait_handles_recv is not None:
for req, req_handle in fwd_wait_handles_recv.items():
req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_recv = None
def pp_post_forward(output_tensor):
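# Advances the microbatch counter for the chunk that just ran and sends the
# output forward; on the last pipeline stage the master chunk's output is fed
# directly to the slave chunk as its next input. Returns the next forward input.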
nonlocal master_cur_microbatch
nonlocal slave_cur_microbatch
nonlocal fwd_wait_handles
nonlocal fwd_wait_handles_slave_chunk
nonlocal firstFB_no_overlp_handle
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch == master_microbatch_max)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
# Synchronize the last slave-chunk forward send from the previous phase
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
if not forward_only and firstFB_no_overlp_handle is not None:
for req, req_handle in firstFB_no_overlp_handle.items():
if req_handle is not None:
req_handle.wait()
firstFB_no_overlp_handle = None
if fwd_send_only:
input_tensor = None
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
if not forward_only:
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
if not forward_only:
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if not forward_only and firstFB_no_overlp_handle is not None:
for req, req_handle in firstFB_no_overlp_handle.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
firstFB_no_overlp_handle = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
return input_tensor
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
def pp_pre_backward():
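# Completes any outstanding communication handles from the previous step
# before the backward tensors are popped.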
nonlocal bwd_wait_handles
if fwd_wait_handles is not None:
for req, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if bwd_wait_handles is not None:
for _, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
def pp_post_backward(input_tensor_grad):
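# Releases the forward output and routes the input gradient: on the last
# pipeline stage it directly becomes the next output gradient, otherwise it
# is sent backward while the next output gradient is received asynchronously.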
nonlocal fwd_wait_handles
nonlocal bwd_wait_handles
if fwd_wait_handles is not None:
for _, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad = input_tensor_grad
else:
output_tensor_grad, bwd_wait_handles = send_backward_recv_slave_backward(
input_tensor_grad,
tensor_shape,
config,
fwd_model_chunk_id,
async_op=True
)
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
output_tensor_grad = None
return output_tensor_grad
# forward
pp_pre_forward()
fwd_microbatch = master_cur_microbatch if fwd_model_chunk_id == master_chunk_id else slave_cur_microbatch
output_tensor = forward_step_helper(
input_tensor,
fwd_model_chunk_id,
fwd_microbatch,
is_first_microbatch=False
)
input_tensor = pp_post_forward(output_tensor)
# backward
pp_pre_backward()
if not forward_only:
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
else:
input_tensor_grad = None
output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)
# only run backward
else:
@@ -970,13 +1047,9 @@ def forward_backward_pipelining_with_cutinhalf(
req_handle.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
@@ -1022,15 +1095,7 @@ def forward_backward_pipelining_with_cutinhalf(
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.start_decouple()
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.end_decouple()
input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
@@ -1048,15 +1113,6 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
# WeightGradStore.flush_chunk_grad()
# if i >= schedule['cooldown'][rank][0] - 1:
# WeightGradStore.pop_single()
# for _ in range(schedule['cooldown'][rank][2] - 1):
# WeightGradStore.pop_single()
# assert WeightGradStore.weight_grad_queue.empty()
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
......