Commit 2b81ee55 authored by dongcl's avatar dongcl
Browse files

dualpipev supports evaluation mode

parent f3434cc7
......@@ -597,22 +597,24 @@ def forward_backward_pipelining_with_cutinhalf(
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=(i == 0),
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor_warmup)
master_cur_microbatch += 1
if i != schedule['warmup'][rank] - 1:
input_tensor, _ = send_forward_recv_forward(
output_tensor_warmup, tensor_shape, config, master_chunk_id)
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
if not forward_only:
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
else:
input_tensor, _ = recv_forward(
tensor_shape, config, master_chunk_id)
......@@ -644,14 +646,15 @@ def forward_backward_pipelining_with_cutinhalf(
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
is_first_microbatch=is_first_microbatch,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, is_first_microbatch),
current_microbatch=master_cur_microbatch
)
total_num_tokens += num_tokens.item()
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
if not forward_only:
input_tensors[master_chunk_id].append(
(master_cur_microbatch, input_tensor))
output_tensors[master_chunk_id].append(output_tensor)
master_cur_microbatch += 1
......@@ -659,13 +662,18 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor(
output_tensor_send, config.deallocate_pipeline_outputs)
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True
if not forward_only:
input_tensor_slave_chunk = output_tensor.detach()
input_tensor_slave_chunk.requires_grad = True
else:
input_tensor_slave_chunk = output_tensor
input_tensor, fwd_wait_handles = recv_forward(
tensor_shape, config, master_chunk_id, async_op=True)
......@@ -680,18 +688,22 @@ def forward_backward_pipelining_with_cutinhalf(
for req, req_handle in fwd_wait_handles_warmup.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
fwd_wait_handles_warmup = None
if not forward_only:
deallocate_output_tensor(
output_tensor_warmup, config.deallocate_pipeline_outputs)
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
set_dualpipe_chunk(slave_chunk_id)
output_tensor_slave_chunk, num_tokens = forward_step_no_model_graph(
forward_step_func,
......@@ -707,15 +719,16 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch,
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
if i == schedule['interleaved_forward'][rank] - 1:
firstFB_no_overlp = False
firstFB_no_overlp_handle = None
# last rank not overlap first F&B
......@@ -749,33 +762,37 @@ def forward_backward_pipelining_with_cutinhalf(
for _ in range(schedule['1b1w1f'][rank]):
# WeightGradStore.start_decouple()
if not forward_only:
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
# WeightGradStore.end_decouple()
if fwd_wait_handles_slave_chunk is not None:
for req in fwd_wait_handles_slave_chunk:
req.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
if fwd_wait_handles_send is not None:
for req, req_handle in fwd_wait_handles_send.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
fwd_wait_handles_send = None
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
# If asynchronous, the memory will rise.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
if not forward_only:
# If asynchronous, the memory will rise.
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, slave_chunk_id)
# If asynchronous, the memory will rise.
input_tensor_slave_chunk, recv_forward_handle = recv_forward(
......@@ -806,15 +823,17 @@ def forward_backward_pipelining_with_cutinhalf(
current_microbatch=slave_cur_microbatch
)
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
total_num_tokens += num_tokens.item()
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
if not forward_only:
input_tensors[slave_chunk_id].append(
(slave_cur_microbatch, input_tensor_slave_chunk))
output_tensors[slave_chunk_id].append(output_tensor_slave_chunk)
slave_cur_microbatch += 1
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
if not forward_only:
output_tensor_grad_bwd, _ = recv_backward(
tensor_shape, config, slave_chunk_id)
fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
tensor_shape, config, slave_chunk_id, async_op=True)
......@@ -847,36 +866,42 @@ def forward_backward_pipelining_with_cutinhalf(
checkpoint_activations_microbatch,
current_microbatch=fwd_microbatch
)
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
total_num_tokens += num_tokens.item()
output_tensors[fwd_model_chunk_id].append(output_tensor)
if not forward_only:
input_tensors[fwd_model_chunk_id].append(
(fwd_microbatch, input_tensor))
output_tensors[fwd_model_chunk_id].append(output_tensor)
if fwd_model_chunk_id == master_chunk_id:
master_cur_microbatch += 1
fwd_send_only = False
else:
slave_cur_microbatch += 1
fwd_send_only = (master_cur_microbatch ==
master_microbatch_max)
fwd_send_only = (master_cur_microbatch == master_microbatch_max)
# 同步上个阶段最后一个slave前向send
if fwd_wait_handles_slave_chunk is not None:
for req, req_handle in fwd_wait_handles_slave_chunk.items():
if req_handle is not None:
req_handle.wait()
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
fwd_wait_handles_slave_chunk = None
if not forward_only:
deallocate_output_tensor(
output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
if fwd_send_only:
fwd_wait_handles = send_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
if not forward_only:
input_tensor = output_tensor.detach()
input_tensor.requires_grad = True
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
else:
input_tensor = output_tensor
else:
input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
......@@ -887,123 +912,128 @@ def forward_backward_pipelining_with_cutinhalf(
req_handle.wait()
firstFB_no_overlp_handle = None
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if not forward_only:
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if fwd_wait_handles is not None:
for req, req_handle in fwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
fwd_wait_handles = None
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if not forward_only:
deallocate_output_tensor(
output_tensor, config.deallocate_pipeline_outputs)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
if not forward_only:
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id, async_op=True)
# only run backward
else:
if bwd_model_chunk_id == slave_chunk_id and slave_cur_microbatch < slave_microbatch_max:
input_tensor, _ = recv_forward(
tensor_shape, config, slave_chunk_id)
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if not forward_only:
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[
1]
output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(
0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_forward_recv_slave_forward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, fwd_model_chunk_id)
# swap fwd & bwd chunks
fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []
while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
if len(input_tensors[bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[1 - bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
bwd_wait_handles_recv = None
for i in range(pp_size):
if not forward_only:
# Run cooldown phases
merged_input_tensors = []
merged_output_tensors = []
while len(input_tensors[0]) > 0 or len(input_tensors[1]) > 0:
if len(input_tensors[bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[bwd_model_chunk_id].pop(0), bwd_model_chunk_id))
if len(input_tensors[1 - bwd_model_chunk_id]) > 0:
merged_input_tensors.append(
input_tensors[1 - bwd_model_chunk_id].pop(0))
merged_output_tensors.append(
(output_tensors[1 - bwd_model_chunk_id].pop(0), 1 - bwd_model_chunk_id))
bwd_wait_handles_recv = None
for i in range(pp_size):
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req, req_handle in bwd_wait_handles_recv.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles_recv = None
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if bwd_wait_handles_recv is not None:
for req, req_handle in bwd_wait_handles_recv.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles_recv = None
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.start_decouple()
input_tensor_bwd = merged_input_tensors.pop(0)[1]
output_tensor_bwd, bwd_model_chunk_id = merged_output_tensors.pop(0)
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.start_decouple()
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.end_decouple()
input_tensor_grad = backward_step(
input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd, model_type, config
)
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
elif i >= schedule['cooldown'][rank][0] - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
tensor_shape, config, bwd_model_chunk_id, async_op=True)
else:
if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
# if not args.dualpipe_no_dw_detach:
# WeightGradStore.end_decouple()
if i == pp_size - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
elif i >= schedule['cooldown'][rank][0] - 1:
bwd_wait_handles = send_backward(input_tensor_grad,
tensor_shape, config, bwd_model_chunk_id, async_op=True)
output_tensor_grad_bwd, bwd_wait_handles_recv = recv_backward(
tensor_shape, config, bwd_model_chunk_id, async_op=True)
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
if parallel_state.is_pipeline_last_stage() and (1 - bwd_model_chunk_id) == master_chunk_id:
output_tensor_grad_bwd = input_tensor_grad
else:
# send_backward_recv_slave_backward
output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(input_tensor_grad,
tensor_shape, config, 1 - bwd_model_chunk_id)
# WeightGradStore.flush_chunk_grad()
# if i >= schedule['cooldown'][rank][0] - 1:
......@@ -1014,11 +1044,11 @@ def forward_backward_pipelining_with_cutinhalf(
# assert WeightGradStore.weight_grad_queue.empty()
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if bwd_wait_handles is not None:
for req, req_handle in bwd_wait_handles.items():
if req_handle is not None:
req_handle.wait()
bwd_wait_handles = None
if config.finalize_model_grads_func is not None and not forward_only:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment