Commit 2b81ee55 authored by dongcl

dualpipev supports evaluation mode

parent f3434cc7
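In brief, this change makes the cut-in-half (dualpipev) schedule usable in evaluation mode: the first-microbatch flag is now derived through check_first_val_step(first_val_step, forward_only, ...), every place that stashes input/output tensors for backward, deallocates pipeline outputs, or exchanges gradients is guarded by "if not forward_only:", and one backward communication call is corrected from send_forward_recv_slave_forward to send_backward_recv_slave_backward. Below is a rough, self-contained Python sketch of the guarding pattern; the loop, dummy_forward_step, and the body of check_first_val_step are illustrative stand-ins based on this diff, not the actual Megatron code.

# Minimal sketch of the forward_only guarding pattern applied throughout
# forward_backward_pipelining_with_cutinhalf. All names below are stand-ins.

def check_first_val_step(first_val_step, forward_only, cond):
    # Assumed behaviour of the helper referenced in the diff: during
    # evaluation the caller may override what counts as the first microbatch.
    if forward_only and first_val_step is not None:
        return first_val_step
    return cond

def dummy_forward_step(microbatch, is_first_microbatch):
    # Stand-in for forward_step_no_model_graph: returns an output "tensor"
    # and the number of tokens it processed.
    return {"microbatch": microbatch, "first": is_first_microbatch}, 128

def run_chunk(num_microbatches, forward_only, first_val_step=None):
    input_tensors, output_tensors = [], []
    total_num_tokens = 0
    for i in range(num_microbatches):
        # The forward step runs in both training and evaluation.
        output_tensor, num_tokens = dummy_forward_step(
            i,
            is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
        )
        total_num_tokens += num_tokens
        if not forward_only:
            # Only keep activations (and, in the real schedule, deallocate
            # pipeline outputs and exchange gradients) when a backward pass
            # will actually follow.
            input_tensors.append((i, None))
            output_tensors.append(output_tensor)
    return total_num_tokens, len(output_tensors)

# Training keeps per-microbatch state for backward; evaluation keeps none.
assert run_chunk(4, forward_only=False) == (512, 4)
assert run_chunk(4, forward_only=True, first_val_step=True) == (512, 0)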
@@ -597,11 +597,12 @@ def forward_backward_pipelining_with_cutinhalf(
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=(i == 0),
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
@@ -611,6 +612,7 @@ def forward_backward_pipelining_with_cutinhalf(
if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id)
if not forward_only:
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
else:
@@ -644,11 +646,12 @@ def forward_backward_pipelining_with_cutinhalf(
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if not forward_only:
input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True
else:
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
@@ -680,17 +688,21 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_warmup.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_warmup = None
if not forward_only:
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch,
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if i == schedule['interleaved_forward'][rank] - 1:
firstFB_no_overlp = False
firstFB_no_overlp_handle = None
# last rank not overlap first F&B
@@ -749,7 +762,7 @@ def forward_backward_pipelining_with_cutinhalf(
for _ in range(schedule['1b1w1f'][rank]):
# WeightGradStore.start_decouple()
if not forward_only:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
@@ -762,17 +775,21 @@ def forward_backward_pipelining_with_cutinhalf(
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_wait_handles_send is not None:
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if not forward_only:
# If asynchronous, the memory will rise.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
@@ -806,13 +823,15 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if not forward_only:
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
@@ -847,9 +866,12 @@ def forward_backward_pipelining_with_cutinhalf(
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
@@ -857,26 +879,29 @@ def forward_backward_pipelining_with_cutinhalf(
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
fwd_send_only = (master_cur_microbatch == master_microbatch_max)
# Synchronize the last slave-chunk forward send from the previous phase
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
if not forward_only:
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
@@ -887,6 +912,7 @@ def forward_backward_pipelining_with_cutinhalf(
req_handle.wait()
firstFB_no_overlp_handle = None
if not forward_only:
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
@@ -907,9 +933,11 @@ def forward_backward_pipelining_with_cutinhalf(
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
@@ -922,6 +950,7 @@ def forward_backward_pipelining_with_cutinhalf(
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if not forward_only:
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
@@ -940,12 +969,13 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
if not forward_only:
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []