Commit 2b81ee55 authored by dongcl

dualpipev supports evaluation mode

parent f3434cc7
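Context for the diff below: with this change the dualpipev schedule accepts `forward_only=True` and can run evaluation without any backward bookkeeping (no stashing into `input_tensors`/`output_tensors`, no `backward_step`, no gradient sends). The sketch below shows the kind of caller this enables. It is an illustration only; the wrapper name, argument list, and validation iterator are assumptions in the style of Megatron-like `forward_backward_func` callers, not APIs introduced by this commit.

```python
import torch


def run_validation_step(forward_backward_func, forward_step_func, val_data_iterator,
                        model, num_microbatches, seq_length, micro_batch_size,
                        first_val_step=None):
    """Hypothetical helper: drive one evaluation pass through the dualpipev schedule.

    `forward_backward_func` is expected to be forward_backward_pipelining_with_cutinhalf
    (or whatever the framework's get_forward_backward_func() resolves to). All names
    here are illustrative assumptions, not part of this commit.
    """
    with torch.no_grad():
        forward_data = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=val_data_iterator,
            model=model,
            num_microbatches=num_microbatches,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            forward_only=True,              # evaluation mode enabled by this commit
            first_val_step=first_val_step,  # consumed via check_first_val_step(...)
        )
    return forward_data
```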
@@ -597,22 +597,24 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=(i == 0),
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                 current_microbatch=master_cur_microbatch
             )
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append(
-                (master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor_warmup)
+            if not forward_only:
+                input_tensors[master_chunk_id].append(
+                    (master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor_warmup)
             master_cur_microbatch += 1

             if i != schedule['warmup'][rank] - 1:
                 input_tensor, _ = send_forward_recv_forward(
                     output_tensor_warmup, tensor_shape, config, master_chunk_id)
-                deallocate_output_tensor(
-                    output_tensor_warmup, config.deallocate_pipeline_outputs)
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_warmup, config.deallocate_pipeline_outputs)
             else:
                 input_tensor, _ = recv_forward(
                     tensor_shape, config, master_chunk_id)
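The warmup (and, in the next hunk, the steady-state) forward calls now derive `is_first_microbatch` through `check_first_val_step(first_val_step, forward_only, ...)` instead of using the raw condition. For reference, the Megatron-LM core helper of that name behaves roughly as sketched below; the exact body used by this repository may differ, so treat this as an approximation rather than the authoritative definition.

```python
def check_first_val_step(first_val_step, forward_only, cond):
    # Approximation of the Megatron-style helper (not copied from this repo).
    # During evaluation the caller may state explicitly whether this is the first
    # validation step; during training the per-microbatch condition is used as-is.
    if (first_val_step is not None) and forward_only:
        return first_val_step and cond
    return cond
```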
@@ -644,14 +646,15 @@ def forward_backward_pipelining_with_cutinhalf(
                 config,
                 collect_non_loss_data,
                 checkpoint_activations_microbatch,
-                is_first_microbatch=is_first_microbatch,
+                is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
                 current_microbatch=master_cur_microbatch
             )
             total_num_tokens += num_tokens.item()
-            input_tensors[master_chunk_id].append(
-                (master_cur_microbatch, input_tensor))
-            output_tensors[master_chunk_id].append(output_tensor)
+            if not forward_only:
+                input_tensors[master_chunk_id].append(
+                    (master_cur_microbatch, input_tensor))
+                output_tensors[master_chunk_id].append(output_tensor)
             master_cur_microbatch += 1
@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
             for req, req_handle in fwd_wait_handles_send.items():
                 if req_handle is not None:
                     req_handle.wait()
-            deallocate_output_tensor(
-                output_tensor_send, config.deallocate_pipeline_outputs)
             fwd_wait_handles_send = None
+            if not forward_only:
+                deallocate_output_tensor(
+                    output_tensor_send, config.deallocate_pipeline_outputs)

         if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-            input_tensor_slave_chunk = output_tensor.detach()
-            input_tensor_slave_chunk.requires_grad = True
+            if not forward_only:
+                input_tensor_slave_chunk = output_tensor.detach()
+                input_tensor_slave_chunk.requires_grad = True
+            else:
+                input_tensor_slave_chunk = output_tensor

         input_tensor, fwd_wait_handles = recv_forward(
             tensor_shape, config, master_chunk_id, async_op=True)
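The second change in this hunk is the chunk-boundary handling on the last pipeline stage, where the master chunk's output becomes the slave chunk's input locally. In training the autograd graph is cut at the boundary (`detach()` plus `requires_grad = True`) so the schedule can drive each chunk's backward explicitly through `backward_step`; in evaluation no graph is needed and the tensor is passed through unchanged. A standalone sketch of that pattern, using generic names rather than the variables in this file:

```python
import torch


def chunk_boundary_input(chunk_output: torch.Tensor, forward_only: bool) -> torch.Tensor:
    # Sketch of the boundary pattern shown in the diff (generic illustration).
    if not forward_only:
        # Training: detach so the upstream chunk's backward is triggered by the
        # schedule rather than by autograd flowing through this tensor.
        boundary_input = chunk_output.detach()
        boundary_input.requires_grad = True
        return boundary_input
    # Evaluation: no backward pass, so reuse the output tensor directly.
    return chunk_output
```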
@@ -680,18 +688,22 @@ def forward_backward_pipelining_with_cutinhalf(
                 for req, req_handle in fwd_wait_handles_warmup.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(
-                    output_tensor_warmup, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_warmup = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_warmup, config.deallocate_pipeline_outputs)

             if fwd_wait_handles_slave_chunk is not None:
                 for req, req_handle in fwd_wait_handles_slave_chunk.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(
-                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

             set_dualpipe_chunk(slave_chunk_id)
             output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
                 forward_step_func,
@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch,
             )
-            input_tensors[slave_chunk_id].append(
-                (slave_cur_microbatch, input_tensor_slave_chunk))
             total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            if not forward_only:
+                input_tensors[slave_chunk_id].append(
+                    (slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1

             if i == schedule['interleaved_forward'][rank] - 1:
                 firstFB_no_overlp = False
                 firstFB_no_overlp_handle = None
                 # last rank not overlap first F&B
@@ -749,33 +762,37 @@ def forward_backward_pipelining_with_cutinhalf(
         for _ in range(schedule['1b1w1f'][rank]):
             # WeightGradStore.start_decouple()
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-            )
+            if not forward_only:
+                input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
+                output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
+                input_tensor_grad = backward_step(
+                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
+                )
             # WeightGradStore.end_decouple()

             if fwd_wait_handles_slave_chunk is not None:
                 for req in fwd_wait_handles_slave_chunk:
                     req.wait()
-                deallocate_output_tensor(
-                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_slave_chunk = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

             if fwd_wait_handles_send is not None:
                 for req, req_handle in fwd_wait_handles_send.items():
                     if req_handle is not None:
                         req_handle.wait()
-                deallocate_output_tensor(
-                    output_tensor, config.deallocate_pipeline_outputs)
                 fwd_wait_handles_send = None
+                if not forward_only:
+                    deallocate_output_tensor(
+                        output_tensor, config.deallocate_pipeline_outputs)

-            # If asynchronous, the memory will rise.
-            bwd_wait_handles = send_backward(input_tensor_grad,
-                                             tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                # If asynchronous, the memory will rise.
+                bwd_wait_handles = send_backward(input_tensor_grad,
+                                                 tensor_shape, config, slave_chunk_id)

             # If asynchronous, the memory will rise.
             input_tensor_slave_chunk, recv_forward_handle = recv_forward(
@@ -806,15 +823,17 @@ def forward_backward_pipelining_with_cutinhalf(
                 current_microbatch=slave_cur_microbatch
             )
-            input_tensors[slave_chunk_id].append(
-                (slave_cur_microbatch, input_tensor_slave_chunk))
             total_num_tokens += num_tokens.item()
-            output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
+            if not forward_only:
+                input_tensors[slave_chunk_id].append(
+                    (slave_cur_microbatch, input_tensor_slave_chunk))
+                output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
             slave_cur_microbatch += 1

-            output_tensor_grad_bwd, _ = recv_backward(
-                tensor_shape, config, slave_chunk_id)
+            if not forward_only:
+                output_tensor_grad_bwd, _ = recv_backward(
+                    tensor_shape, config, slave_chunk_id)

             fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
                                                         tensor_shape, config, slave_chunk_id, async_op=True)
@@ -847,36 +866,42 @@ def forward_backward_pipelining_with_cutinhalf(
                     checkpoint_activations_microbatch,
                     current_microbatch=fwd_microbatch
                 )
-                input_tensors[fwd_model_chunk_id].append(
-                    (fwd_microbatch, input_tensor))
                 total_num_tokens += num_tokens.item()
-                output_tensors[fwd_model_chunk_id].append(output_tensor)
+                if not forward_only:
+                    input_tensors[fwd_model_chunk_id].append(
+                        (fwd_microbatch, input_tensor))
+                    output_tensors[fwd_model_chunk_id].append(output_tensor)

                 if fwd_model_chunk_id == master_chunk_id:
                     master_cur_microbatch += 1
                     fwd_send_only = False
                 else:
                     slave_cur_microbatch += 1
-                    fwd_send_only = (master_cur_microbatch ==
-                                     master_microbatch_max)
+                    fwd_send_only = (master_cur_microbatch == master_microbatch_max)

                 # sync the previous phase's last slave-chunk forward send
                 if fwd_wait_handles_slave_chunk is not None:
                     for req, req_handle in fwd_wait_handles_slave_chunk.items():
                         if req_handle is not None:
                             req_handle.wait()
-                    deallocate_output_tensor(
-                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
                     fwd_wait_handles_slave_chunk = None
+                    if not forward_only:
+                        deallocate_output_tensor(
+                            output_tensor_slave_chunk, config.deallocate_pipeline_outputs)

                 if fwd_send_only:
                     fwd_wait_handles = send_forward(
                         output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
                 else:
                     if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                        input_tensor = output_tensor.detach()
-                        input_tensor.requires_grad = True
-                        deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                        if not forward_only:
+                            input_tensor = output_tensor.detach()
+                            input_tensor.requires_grad = True
+                            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+                        else:
+                            input_tensor = output_tensor
                     else:
                         input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
                             output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
@@ -887,123 +912,128 @@ def forward_backward_pipelining_with_cutinhalf(
                             req_handle.wait()
                     firstFB_no_overlp_handle = None

-                if bwd_wait_handles is not None:
-                    for req, req_handle in bwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    bwd_wait_handles = None
-                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
-                    1]
-                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
-                    0)
-                input_tensor_grad = backward_step(
-                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-                )
+                if not forward_only:
+                    if bwd_wait_handles is not None:
+                        for req, req_handle in bwd_wait_handles.items():
+                            if req_handle is not None:
+                                req_handle.wait()
+                        bwd_wait_handles = None
+                    input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
+                        1]
+                    output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
+                        0)
+                    input_tensor_grad = backward_step(
+                        input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
+                    )

                 if fwd_wait_handles is not None:
                     for req, req_handle in fwd_wait_handles.items():
                         if req_handle is not None:
                             req_handle.wait()
                     fwd_wait_handles = None
-                    deallocate_output_tensor(
-                        output_tensor, config.deallocate_pipeline_outputs)
+                    if not forward_only:
+                        deallocate_output_tensor(
+                            output_tensor, config.deallocate_pipeline_outputs)

-                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
-                else:
-                    # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
-                                                                                                  tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                if not forward_only:
+                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                        output_tensor_grad_bwd = input_tensor_grad
+                    else:
+                        # send_backward_recv_slave_backward
+                        output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
                                                                                                       tensor_shape, config, fwd_model_chunk_id, async_op=True)
             # only run backward
             else:
                 if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
                     input_tensor, _ = recv_forward(
                         tensor_shape, config, slave_chunk_id)

-                if bwd_wait_handles is not None:
-                    for req, req_handle in bwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    bwd_wait_handles = None
-
-                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
-                    1]
-                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
-                    0)
-                input_tensor_grad = backward_step(
-                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-                )
-
-                if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
-                else:
-                    # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
-                                                                                               tensor_shape, config, fwd_model_chunk_id)
+                if not forward_only:
+                    if bwd_wait_handles is not None:
+                        for req, req_handle in bwd_wait_handles.items():
+                            if req_handle is not None:
+                                req_handle.wait()
+                        bwd_wait_handles = None
+
+                    input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
+                        1]
+                    output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
+                        0)
+                    input_tensor_grad = backward_step(
+                        input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
+                    )
+
+                    if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
+                        output_tensor_grad_bwd = input_tensor_grad
+                    else:
+                        # send_backward_recv_slave_backward
+                        output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
                                                                                                       tensor_shape, config, fwd_model_chunk_id)
             # swap fwd & bwd chunks
             fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id

-        # Run cooldown phases
-        merged_input_tensors = []
-        merged_output_tensors = []
-        while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
-            if len(input_tensors[bwd_model_chunk_id]) > 0:
-                merged_input_tensors.append(
-                    input_tensors[bwd_model_chunk_id].pop(0))
-                merged_output_tensors.append(
-                    (output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
-
-            if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
-                merged_input_tensors.append(
-                    input_tensors[1 - bwd_model_chunk_id].pop(0))
-                merged_output_tensors.append(
-                    (output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
-
-        bwd_wait_handles_recv = None
-        for i in range(pp_size):
-            if bwd_wait_handles is not None:
-                for req, req_handle in bwd_wait_handles.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles = None
-
-            if bwd_wait_handles_recv is not None:
-                for req, req_handle in bwd_wait_handles_recv.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                bwd_wait_handles_recv = None
-
-            input_tensor_bwd = merged_input_tensors.pop(0)[1]
-            output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
-
-            # if not args.dualpipe_no_dw_detach:
-            # WeightGradStore.start_decouple()
-
-            input_tensor_grad = backward_step(
-                input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
-            )
-
-            # if not args.dualpipe_no_dw_detach:
-            # WeightGradStore.end_decouple()
-
-            if i == pp_size - 1:
-                bwd_wait_handles = send_backward(input_tensor_grad,
-                                                 tensor_shape, config, bwd_model_chunk_id, async_op=True)
-            elif i >= schedule['cooldown'][rank][0] - 1:
-                bwd_wait_handles = send_backward(input_tensor_grad,
-                                                 tensor_shape, config, bwd_model_chunk_id, async_op=True)
-                output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
-                    tensor_shape, config, bwd_model_chunk_id, async_op=True)
-            else:
-                if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
-                else:
-                    # send_backward_recv_slave_backward
-                    output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
-                                                                                                  tensor_shape, config, 1 - bwd_model_chunk_id)
-
-            # WeightGradStore.flush_chunk_grad()
-            # if i >= schedule['cooldown'][rank][0] - 1:
+        if not forward_only:
+            # Run cooldown phases
+            merged_input_tensors = []
+            merged_output_tensors = []
+            while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
+                if len(input_tensors[bwd_model_chunk_id]) > 0:
+                    merged_input_tensors.append(
+                        input_tensors[bwd_model_chunk_id].pop(0))
+                    merged_output_tensors.append(
+                        (output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
+
+                if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
+                    merged_input_tensors.append(
+                        input_tensors[1 - bwd_model_chunk_id].pop(0))
+                    merged_output_tensors.append(
+                        (output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
+
+            bwd_wait_handles_recv = None
+            for i in range(pp_size):
+                if bwd_wait_handles is not None:
+                    for req, req_handle in bwd_wait_handles.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    bwd_wait_handles = None
+
+                if bwd_wait_handles_recv is not None:
+                    for req, req_handle in bwd_wait_handles_recv.items():
+                        if req_handle is not None:
+                            req_handle.wait()
+                    bwd_wait_handles_recv = None
+
+                input_tensor_bwd = merged_input_tensors.pop(0)[1]
+                output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
+
+                # if not args.dualpipe_no_dw_detach:
+                # WeightGradStore.start_decouple()
+
+                input_tensor_grad = backward_step(
+                    input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
+                )
+
+                # if not args.dualpipe_no_dw_detach:
+                # WeightGradStore.end_decouple()
+
+                if i == pp_size - 1:
+                    bwd_wait_handles = send_backward(input_tensor_grad,
+                                                     tensor_shape, config, bwd_model_chunk_id, async_op=True)
+                elif i >= schedule['cooldown'][rank][0] - 1:
+                    bwd_wait_handles = send_backward(input_tensor_grad,
+                                                     tensor_shape, config, bwd_model_chunk_id, async_op=True)
+                    output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
+                        tensor_shape, config, bwd_model_chunk_id, async_op=True)
+                else:
+                    if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
+                        output_tensor_grad_bwd = input_tensor_grad
+                    else:
+                        # send_backward_recv_slave_backward
+                        output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
+                                                                                                      tensor_shape, config, 1 - bwd_model_chunk_id)
+
+                # WeightGradStore.flush_chunk_grad()
+                # if i >= schedule['cooldown'][rank][0] - 1:
@@ -1014,11 +1044,11 @@ def forward_backward_pipelining_with_cutinhalf(
-        # assert WeightGradStore.weight_grad_queue.empty()
-
-        if bwd_wait_handles is not None:
-            for req, req_handle in bwd_wait_handles.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            bwd_wait_handles = None
+            # assert WeightGradStore.weight_grad_queue.empty()
+
+            if bwd_wait_handles is not None:
+                for req, req_handle in bwd_wait_handles.items():
+                    if req_handle is not None:
+                        req_handle.wait()
+                bwd_wait_handles = None

     if config.finalize_model_grads_func is not None and not forward_only: