Commit 2b81ee55 authored by dongcl's avatar dongcl
Browse files

dualpipev supports evaluation mode

parent f3434cc7
...@@ -597,11 +597,12 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -597,11 +597,12 @@ def forward_backward_pipelining_with_cutinhalf(
config, config,
collect_non_loss_data, collect_non_loss_data,
checkpoint_activations_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=(i == 0), is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=master_cur_microbatch current_microbatch=master_cur_microbatch
) )
total_num_tokens += num_tokens.item() total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append( input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor)) (master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup) output_tensors[master_chunk_id].append(output_tensor_warmup)
...@@ -611,6 +612,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -611,6 +612,7 @@ def forward_backward_pipelining_with_cutinhalf(
if i != schedule['warmup'][rank] - 1: if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward( input_tensor, _ = send_forward_recv_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id) output_tensor_warmup, tensor_shape, config, master_chunk_id)
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs) output_tensor_warmup, config.deallocate_pipeline_outputs)
else: else:
...@@ -644,11 +646,12 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -644,11 +646,12 @@ def forward_backward_pipelining_with_cutinhalf(
config, config,
collect_non_loss_data, collect_non_loss_data,
checkpoint_activations_microbatch, checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch, is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
current_microbatch=master_cur_microbatch current_microbatch=master_cur_microbatch
) )
total_num_tokens += num_tokens.item() total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append( input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor)) (master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor) output_tensors[master_chunk_id].append(output_tensor)
...@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_send.items(): for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs) output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage(ignore_virtual=True): if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if not forward_only:
input_tensor_slave_chunk = output_tensor.detach() input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True input_tensor_slave_chunk.requires_grad = True
else:
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward( input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True) tensor_shape, config, master_chunk_id, async_op=True)
...@@ -680,17 +688,21 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -680,17 +688,21 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_warmup.items(): for req, req_handle in fwd_wait_handles_warmup.items():
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles_warmup = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs) output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if fwd_wait_handles_slave_chunk is not None: if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items(): for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs) output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
set_dualpipe_chunk(slave_chunk_id) set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph( output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
...@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch, current_microbatch=slave_cur_microbatch,
) )
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append( input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk)) (slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk) output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1 slave_cur_microbatch += 1
if i == schedule['interleaved_forward'][rank] - 1: if i == schedule['interleaved_forward'][rank] - 1:
firstFB_no_overlp = False firstFB_no_overlp = False
firstFB_no_overlp_handle = None firstFB_no_overlp_handle = None
# last rank not overlap first F&B # last rank not overlap first F&B
...@@ -749,7 +762,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -749,7 +762,7 @@ def forward_backward_pipelining_with_cutinhalf(
for _ in range(schedule['1b1w1f'][rank]): for _ in range(schedule['1b1w1f'][rank]):
# WeightGradStore.start_decouple() # WeightGradStore.start_decouple()
if not forward_only:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1] input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0) output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
...@@ -762,17 +775,21 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -762,17 +775,21 @@ def forward_backward_pipelining_with_cutinhalf(
if fwd_wait_handles_slave_chunk is not None: if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk: for req in fwd_wait_handles_slave_chunk:
req.wait() req.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs) output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_wait_handles_send is not None: if fwd_wait_handles_send is not None:
for req, req_handle in fwd_wait_handles_send.items(): for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs) output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if not forward_only:
# If asynchronous, the memory will rise. # If asynchronous, the memory will rise.
bwd_wait_handles = send_backward(input_tensor_grad, bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id) tensor_shape, config, slave_chunk_id)
...@@ -806,13 +823,15 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -806,13 +823,15 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch current_microbatch=slave_cur_microbatch
) )
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append( input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk)) (slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk) output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1 slave_cur_microbatch += 1
if not forward_only:
output_tensor_grad_bwd, _ = recv_backward( output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id) tensor_shape, config, slave_chunk_id)
...@@ -847,9 +866,12 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -847,9 +866,12 @@ def forward_backward_pipelining_with_cutinhalf(
checkpoint_activations_microbatch, checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch current_microbatch=fwd_microbatch
) )
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[fwd_model_chunk_id].append( input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor)) (fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor) output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id: if fwd_model_chunk_id == master_chunk_id:
...@@ -857,26 +879,29 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -857,26 +879,29 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_send_only = False fwd_send_only = False
else: else:
slave_cur_microbatch += 1 slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch == fwd_send_only = (master_cur_microbatch == master_microbatch_max)
master_microbatch_max)
# 同步上个阶段最后一个slave前向send # 同步上个阶段最后一个slave前向send
if fwd_wait_handles_slave_chunk is not None: if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items(): for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs) output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_send_only: if fwd_send_only:
fwd_wait_handles = send_forward( fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True) output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else: else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id: if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
if not forward_only:
input_tensor = output_tensor.detach() input_tensor = output_tensor.detach()
input_tensor.requires_grad = True input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor = output_tensor
else: else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward( input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True) output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
...@@ -887,6 +912,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -887,6 +912,7 @@ def forward_backward_pipelining_with_cutinhalf(
req_handle.wait() req_handle.wait()
firstFB_no_overlp_handle = None firstFB_no_overlp_handle = None
if not forward_only:
if bwd_wait_handles is not None: if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items(): for req, req_handle in bwd_wait_handles.items():
if req_handle is not None: if req_handle is not None:
...@@ -907,9 +933,11 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -907,9 +933,11 @@ def forward_backward_pipelining_with_cutinhalf(
if req_handle is not None: if req_handle is not None:
req_handle.wait() req_handle.wait()
fwd_wait_handles = None fwd_wait_handles = None
if not forward_only:
deallocate_output_tensor( deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs) output_tensor, config.deallocate_pipeline_outputs)
if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id: if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad output_tensor_grad_bwd = input_tensor_grad
else: else:
...@@ -922,6 +950,7 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -922,6 +950,7 @@ def forward_backward_pipelining_with_cutinhalf(
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max: if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward( input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id) tensor_shape, config, slave_chunk_id)
if not forward_only:
if bwd_wait_handles is not None: if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items(): for req, req_handle in bwd_wait_handles.items():
if req_handle is not None: if req_handle is not None:
...@@ -940,12 +969,13 @@ def forward_backward_pipelining_with_cutinhalf( ...@@ -940,12 +969,13 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd = input_tensor_grad output_tensor_grad_bwd = input_tensor_grad
else: else:
# send_backward_recv_slave_backward # send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad, output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id) tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks # swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
if not forward_only:
# Run cooldown phases # Run cooldown phases
merged_input_tensors = [] merged_input_tensors = []
merged_output_tensors = [] merged_output_tensors = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment