Commit 62f16817 authored by dongcl

evaluate support dualpipev

parent 2385a133
@@ -48,6 +48,7 @@ class PipelineFeature(AbstractFeature):
             train_step,
             _allreduce_embedding_grads_wrapper
         )
+        from dcu_megatron.training.training import evaluate

         patch_manager.register_patch(
             'megatron.training.training.get_model', get_model)
@@ -64,6 +65,10 @@ class PipelineFeature(AbstractFeature):
         patch_manager.register_patch(
             'megatron.core.distributed.finalize_model_grads._allreduce_embedding_grads', _allreduce_embedding_grads_wrapper)
+        # use first rank
+        patch_manager.register_patch(
+            'megatron.training.training.evaluate', evaluate)
+
         if (
             args.schedule_method == "interleaved_1f1b"
             and args.combined_1f1b
...
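The hunks above route Megatron's evaluation through the dualpipev-aware evaluate via the feature's patch mechanism. For readers unfamiliar with that mechanism, the sketch below illustrates the general monkey-patching idea behind register_patch; this PatchManager class is illustrative only and is not the dcu_megatron implementation.

# Illustrative sketch of the register_patch idea: the dotted path names an
# attribute in an importable module, and the registered callable replaces it.
import importlib


class PatchManager:
    def __init__(self):
        self._patches = []

    def register_patch(self, dotted_path, replacement):
        # 'a.b.c.func' -> module 'a.b.c', attribute 'func'
        module_path, attr_name = dotted_path.rsplit('.', 1)
        self._patches.append((module_path, attr_name, replacement))

    def apply(self):
        # Apply all registered patches before training/evaluation starts.
        for module_path, attr_name, replacement in self._patches:
            module = importlib.import_module(module_path)
            setattr(module, attr_name, replacement)


# Usage sketch, mirroring the registration in the hunk above:
# patch_manager = PatchManager()
# patch_manager.register_patch('megatron.training.training.evaluate', evaluate)
# patch_manager.apply()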
@@ -25,6 +25,8 @@ from megatron.core.pipeline_parallel.schedules import (
     check_first_val_step,
     finish_embedding_wgrad_compute
 )
+from dcu_megatron.training.utils import print_rank_message

 # from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
@@ -584,7 +586,11 @@ def forward_backward_pipelining_with_cutinhalf(
         fwd_model_chunk_id,
         bwd_model_chunk_id,
         fwd_input_tensor=None,
-        bwd_output_tensor_grad=None
+        bwd_output_tensor_grad=None,
+        pre_forward=None,
+        pre_backward=None,
+        post_forward=None,
+        post_backward=None,
     ):
         """Helper method to run combined forward and backward step"""
         # forward prepare
@@ -670,9 +676,9 @@ def forward_backward_pipelining_with_cutinhalf(
     total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
     forward_data_store = []
-    if not forward_only:
-        input_tensors = [[], []]
-        output_tensors = [[], []]
+    input_tensors = [[], []]
+    output_tensors = [[], []]
+    output_tensor_grads = [[], []]

     master_chunk_id = 0
     slave_chunk_id = 1
@@ -681,10 +687,23 @@ def forward_backward_pipelining_with_cutinhalf(
     num_chunk_max_microbatch = [num_microbatches, num_microbatches * 2]
     checkpoint_activations_microbatch = None
-    fwd_wait_handles_warmup = None

-    def forward_step_helper(input_tensor, model_chunk_id, cur_microbatch, is_first_microbatch=False):
+    def wait_comm_handles(comm_handles):
+        if comm_handles is None:
+            return
+        for _, req_handle in comm_handles.items():
+            if req_handle is not None:
+                req_handle.wait()
+        comm_handles = None
+
+    def forward_step_helper(model_chunk_id, cur_microbatch, is_first_microbatch=False):
         set_dualpipe_chunk(model_chunk_id)
+        if not forward_only:
+            offset = cur_bwd_chunk_microbatch[model_chunk_id]
+            input_tensor = input_tensors[model_chunk_id][cur_microbatch - offset]
+        else:
+            input_tensor = input_tensors[model_chunk_id][0]
         output_tensor, num_tokens = forward_step_no_model_graph(
             forward_step_func,
             model_chunk_id,
@@ -699,17 +718,17 @@ def forward_backward_pipelining_with_cutinhalf(
             is_first_microbatch=is_first_microbatch,
             current_microbatch=cur_microbatch
         )
+        output_tensors[model_chunk_id].append(output_tensor)

         nonlocal total_num_tokens
         total_num_tokens += num_tokens.item()

-        if not forward_only:
-            input_tensors[model_chunk_id].append(
-                (cur_microbatch, input_tensor))
-            output_tensors[model_chunk_id].append(output_tensor)
+        if forward_only:
+            input_tensors[model_chunk_id].pop(0)
+            output_tensors[model_chunk_id].pop()

         return output_tensor

-    def backward_step_helper(input_tensor, output_tensor, output_tensor_grad, bwd_model_chunk_id=None, bwd_cur_microbatch=None):
+    def backward_step_helper(model_chunk_id, bwd_cur_microbatch=None):
         nonlocal master_chunk_id
         nonlocal slave_chunk_id
         nonlocal num_chunk_max_microbatch
@@ -721,207 +740,187 @@ def forward_backward_pipelining_with_cutinhalf(
         # bubble.
         if (
             bwd_cur_microbatch is not None
-            and bwd_cur_microbatch == num_chunk_max_microbatch[bwd_model_chunk_id] - 1
+            and bwd_cur_microbatch == num_chunk_max_microbatch[model_chunk_id] - 1
         ):
             if (
                 config.grad_sync_func is None
-                or (bwd_model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
-                or (bwd_model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
+                or (model_chunk_id == slave_chunk_id and parallel_state.is_pipeline_last_stage())
+                or (model_chunk_id == master_chunk_id and parallel_state.is_pipeline_first_stage())
             ):
                 enable_grad_sync()

+        input_tensor = input_tensors[model_chunk_id].pop(0)
+        output_tensor = output_tensors[model_chunk_id].pop(0)
+        output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
+
         input_tensor_grad = backward_step(
             input_tensor, output_tensor, output_tensor_grad, model_type, config
         )

         return input_tensor_grad

+    output_tensor_master_send = None
+    output_tensor_slave_send = None
+    fwd_wait_recv_handles = [None, None]
+    fwd_wait_send_handles = [None, None]
+    bwd_wait_recv_handles = [None, None]
+    bwd_wait_send_handles = [None, None]
+
     # Run warmup forward passes
     input_tensor, _ = recv_forward(tensor_shape, config, master_chunk_id)
+    input_tensors[master_chunk_id].append(input_tensor)
     for i in range(schedule['warmup'][rank]):
+        wait_comm_handles(fwd_wait_recv_handles[master_chunk_id])
+        # recv for next iteration
+        input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
+        input_tensors[master_chunk_id].append(input_tensor)
+
         is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
-        output_tensor_warmup = forward_step_helper(
-            input_tensor,
+        output_tensor = forward_step_helper(
             master_chunk_id,
             cur_fwd_chunk_microbatch[master_chunk_id],
             is_first_microbatch=is_first_microbatch
         )
         cur_fwd_chunk_microbatch[master_chunk_id] += 1

-        if i != schedule['warmup'][rank] - 1:
-            input_tensor, _ = send_forward_recv_forward(
-                output_tensor_warmup, tensor_shape, config, master_chunk_id)
-            if not forward_only:
-                deallocate_output_tensor(
-                    output_tensor_warmup, config.deallocate_pipeline_outputs)
-        else:
-            input_tensor, _ = recv_forward(
-                tensor_shape, config, master_chunk_id)
-            fwd_wait_handles_warmup = send_forward(
-                output_tensor_warmup, tensor_shape, config, master_chunk_id, async_op=True)
+        if fwd_wait_send_handles[master_chunk_id] is not None:
+            for req, req_handle in fwd_wait_send_handles[master_chunk_id].items():
+                if req_handle is not None:
+                    req_handle.wait()
+            fwd_wait_send_handles[master_chunk_id] = None
+        if not forward_only:
+            deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs)
+
+        output_tensor_master_send = output_tensor
+        fwd_wait_send_handles[master_chunk_id] = send_forward(output_tensor_master_send, tensor_shape, config, master_chunk_id, async_op=True)

     # Run interleaved forward passes for two model chunk
-    fwd_wait_handles = None
-    fwd_wait_handles_slave_chunk = None
-    fwd_wait_handles_send = None
     for i in range(schedule['interleaved_forward'][rank]):
-        if fwd_wait_handles is not None:
-            for req, req_handle in fwd_wait_handles.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            fwd_wait_handles = None
+        wait_comm_handles(fwd_wait_recv_handles[master_chunk_id])
+
+        if not parallel_state.is_pipeline_last_stage():
+            input_tensor_slave, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
+            input_tensors[slave_chunk_id].append(input_tensor_slave)

         is_first_microbatch = parallel_state.is_pipeline_last_stage(ignore_virtual=True) and (i == 0)
         is_first_microbatch = check_first_val_step(first_val_step, forward_only, is_first_microbatch)
-        output_tensor = forward_step_helper(
-            input_tensor,
+        output_tensor_master = forward_step_helper(
             master_chunk_id,
             cur_fwd_chunk_microbatch[master_chunk_id],
             is_first_microbatch=is_first_microbatch
         )
         cur_fwd_chunk_microbatch[master_chunk_id] += 1

-        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True) and fwd_wait_handles_send is not None:
-            for req, req_handle in fwd_wait_handles_send.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            fwd_wait_handles_send = None
-            if not forward_only:
-                deallocate_output_tensor(
-                    output_tensor_send, config.deallocate_pipeline_outputs)
-
-        if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+        if not parallel_state.is_pipeline_last_stage():
+            wait_comm_handles(fwd_wait_send_handles[master_chunk_id])
+            if not forward_only:
+                deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs)
+
+            output_tensor_master_send = output_tensor_master
+            fwd_wait_send_handles[master_chunk_id] = send_forward(
+                output_tensor_master_send, tensor_shape, config, master_chunk_id, async_op=True)
+
+        # prepare input for slave chunk
+        if parallel_state.is_pipeline_last_stage():
             if not forward_only:
-                input_tensor_slave = output_tensor.detach()
+                input_tensor_slave = output_tensor_master.detach()
                 input_tensor_slave.requires_grad = True
             else:
-                input_tensor_slave = output_tensor
-        else:
-            input_tensor_slave, _ = recv_forward(
-                tensor_shape, config, slave_chunk_id)
-
-        input_tensor, fwd_wait_handles = recv_forward(
-            tensor_shape, config, master_chunk_id, async_op=True)
-
-        if fwd_wait_handles_warmup is not None:
-            for req, req_handle in fwd_wait_handles_warmup.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            fwd_wait_handles_warmup = None
-            if not forward_only:
-                deallocate_output_tensor(
-                    output_tensor_warmup, config.deallocate_pipeline_outputs)
-
-        if fwd_wait_handles_slave_chunk is not None:
-            for req, req_handle in fwd_wait_handles_slave_chunk.items():
-                if req_handle is not None:
-                    req_handle.wait()
-            fwd_wait_handles_slave_chunk = None
-            if not forward_only:
-                deallocate_output_tensor(
-                    output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-
+                input_tensor_slave = output_tensor_master
+            input_tensors[slave_chunk_id].append(input_tensor_slave)
+
+            if not forward_only:
+                deallocate_output_tensor(output_tensor_master, config.deallocate_pipeline_outputs)
+        else:
+            wait_comm_handles(fwd_wait_recv_handles[slave_chunk_id])
+
+        # recv input tensor for master chunk
+        input_tensor, fwd_wait_recv_handles[master_chunk_id] = recv_forward(tensor_shape, config, master_chunk_id, async_op=True)
+        input_tensors[master_chunk_id].append(input_tensor)
+
+        # slave forward
         is_first_microbatch = check_first_val_step(first_val_step, forward_only, i == 0)
-        output_tensor_slave_chunk = forward_step_helper(
-            input_tensor_slave,
+        output_tensor_slave = forward_step_helper(
             slave_chunk_id,
             cur_fwd_chunk_microbatch[slave_chunk_id],
             is_first_microbatch=is_first_microbatch
         )
         cur_fwd_chunk_microbatch[slave_chunk_id] += 1

-        if not forward_only:
-            if i == schedule['interleaved_forward'][rank] - 1:
-                firstFB_no_overlp_handle = None
-                # last rank not overlap first F&B
-                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-                    output_tensor_grad_bwd, firstFB_no_overlp_handle = recv_backward(
-                        tensor_shape, config, slave_chunk_id, async_op=True)
-                else:
-                    output_tensor_grad_bwd, _ = recv_backward(
-                        tensor_shape, config, slave_chunk_id)
-
-        fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
-                                                    tensor_shape, config, slave_chunk_id, async_op=True)
-
-        if not parallel_state.is_pipeline_last_stage(ignore_virtual=True):
-            output_tensor_send = output_tensor
-            fwd_wait_handles_send = send_forward(
-                output_tensor_send, tensor_shape, config, master_chunk_id, async_op=True)
-        else:
-            # custom_backward requires output_tensor.numel() == 1
-            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-
-    if fwd_wait_handles is not None:
-        for req, req_handle in fwd_wait_handles.items():
-            if req_handle is not None:
-                req_handle.wait()
-        fwd_wait_handles = None
+        wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
+        if not forward_only:
+            deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
+
+        output_tensor_slave_send = output_tensor_slave
+        fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)
+
+    # check whether data transmission is completed.
+    wait_comm_handles(fwd_wait_send_handles[master_chunk_id])
+    if not forward_only:
+        deallocate_output_tensor(output_tensor_master_send, config.deallocate_pipeline_outputs)
+
+    wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
+    if not forward_only:
+        deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)

     # Run 1b1w1f stages for slave chunk
-    bwd_wait_handles = None
+    if not forward_only:
+        if parallel_state.is_pipeline_last_stage():
+            output_tensor_grad, bwd_wait_recv_handles[slave_chunk_id] = recv_backward(
+                tensor_shape, config, slave_chunk_id, async_op=True)
+        else:
+            output_tensor_grad, _ = recv_backward(
+                tensor_shape, config, slave_chunk_id)
+        output_tensor_grads[slave_chunk_id].append(output_tensor_grad)
+
+    if not forward_only and parallel_state.is_pipeline_first_stage():
+        deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)
+
     for _ in range(schedule['1b1w1f'][rank]):
+        # If asynchronous, the memory will rise. TODO dongcl
+        input_tensor_slave, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id)
+        input_tensors[slave_chunk_id].append(input_tensor_slave)
+
         if not forward_only:
-            input_tensor_bwd = input_tensors[slave_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[slave_chunk_id].pop(0)
-            input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+            input_tensor_grad = backward_step_helper(slave_chunk_id)
             cur_bwd_chunk_microbatch[slave_chunk_id] += 1

-            if fwd_wait_handles_slave_chunk is not None:
-                for req in fwd_wait_handles_slave_chunk:
-                    req.wait()
-                fwd_wait_handles_slave_chunk = None
-                if not forward_only:
-                    deallocate_output_tensor(
-                        output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-
-            if fwd_wait_handles_send is not None:
-                for req, req_handle in fwd_wait_handles_send.items():
-                    if req_handle is not None:
-                        req_handle.wait()
-                fwd_wait_handles_send = None
-
+            # If asynchronous, the memory will rise.
+            bwd_wait_send_handles[slave_chunk_id] = send_backward(input_tensor_grad, tensor_shape, config, slave_chunk_id)
+
+        wait_comm_handles(fwd_wait_send_handles[slave_chunk_id])
         if not forward_only:
-            deallocate_output_tensor(
-                output_tensor, config.deallocate_pipeline_outputs)
+            deallocate_output_tensor(output_tensor_slave_send, config.deallocate_pipeline_outputs)

         if not forward_only:
-            # If asynchronous, the memory will rise.
-            bwd_wait_handles = send_backward(input_tensor_grad,
-                                             tensor_shape, config, slave_chunk_id)
-
-        # If asynchronous, the memory will rise.
-        input_tensor_slave, recv_forward_handle = recv_forward(
-            tensor_shape, config, slave_chunk_id)
-        if recv_forward_handle is not None:
-            for req, handle in recv_forward_handle.items():
+            output_tensor_grad, _ = recv_backward(tensor_shape, config, slave_chunk_id)
+            output_tensor_grads[slave_chunk_id].append(output_tensor_grad)
+
+        # 1F: Forward pass
+        if fwd_wait_recv_handles[slave_chunk_id] is not None:
+            for req, handle in fwd_wait_recv_handles[slave_chunk_id].items():
                 if handle is not None:
                     handle.wait()
-            recv_forward_handle = None
-
-        # 1F: Forward pass
-        output_tensor = forward_step_helper(
-            input_tensor_slave,
+            fwd_wait_recv_handles[slave_chunk_id] = None
+
+        output_tensor_slave = forward_step_helper(
             slave_chunk_id,
             cur_fwd_chunk_microbatch[slave_chunk_id],
             is_first_microbatch=False
         )
         cur_fwd_chunk_microbatch[slave_chunk_id] += 1

-        if not forward_only:
-            output_tensor_grad_bwd, _ = recv_backward(
-                tensor_shape, config, slave_chunk_id)
-
-        fwd_wait_handles_slave_chunk = send_forward(output_tensor_slave_chunk,
-                                                    tensor_shape, config, slave_chunk_id, async_op=True)
+        # check whether backward data transmission is completed.
+        wait_comm_handles(bwd_wait_send_handles[slave_chunk_id])
+
+        output_tensor_slave_send = output_tensor_slave
+        fwd_wait_send_handles[slave_chunk_id] = send_forward(output_tensor_slave_send, tensor_shape, config, slave_chunk_id, async_op=True)

     # Run overlapping f&bw stages
-    fwd_wait_handles = None
-    bwd_wait_handles = None
-    fwd_wait_handles_recv = None
+    fwd_wait_send_recv_handles = None
+    bwd_wait_send_recv_handles = None
     fwd_model_chunk_id = master_chunk_id
     bwd_model_chunk_id = slave_chunk_id
     num_overlap_steps = schedule['overlap'][rank] + schedule['1b1overlap'][rank]
@@ -934,161 +933,130 @@ def forward_backward_pipelining_with_cutinhalf(
         if not only_bwd:
             def pp_pre_forward():
-                nonlocal fwd_wait_handles_recv
-                if fwd_wait_handles_recv is not None:
-                    for req, req_handle in fwd_wait_handles_recv.items():
-                        req_handle.wait()
-                    fwd_wait_handles_recv = None
+                nonlocal fwd_wait_recv_handles
+                # wait input for current step
+                wait_comm_handles(fwd_wait_recv_handles[fwd_model_chunk_id])

             def pp_post_forward(output_tensor):
                 nonlocal cur_fwd_chunk_microbatch
                 nonlocal num_chunk_max_microbatch
-                nonlocal fwd_wait_handles
-                nonlocal fwd_wait_handles_slave_chunk
-                nonlocal firstFB_no_overlp_handle
+                nonlocal fwd_wait_send_handles
+                nonlocal fwd_wait_send_recv_handles

                 if fwd_model_chunk_id == master_chunk_id:
                     fwd_send_only = False
                 else:
                     fwd_send_only = (cur_fwd_chunk_microbatch[master_chunk_id] == num_chunk_max_microbatch[master_chunk_id])

-                # wait for the previous stage's last slave-chunk forward send
-                if fwd_wait_handles_slave_chunk is not None:
-                    for req, req_handle in fwd_wait_handles_slave_chunk.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    fwd_wait_handles_slave_chunk = None
-                    if not forward_only:
-                        deallocate_output_tensor(
-                            output_tensor_slave_chunk, config.deallocate_pipeline_outputs)
-
                 if fwd_send_only:
-                    input_tensor = None
-                    fwd_wait_handles = send_forward(
-                        output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
+                    fwd_wait_send_handles[fwd_model_chunk_id] = send_forward(output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
                 else:
                     if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                         if not forward_only:
                             input_tensor = output_tensor.detach()
                             input_tensor.requires_grad = True
+                            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
                         else:
                             input_tensor = output_tensor
                     else:
-                        input_tensor, fwd_wait_handles = send_forward_recv_slave_forward(
+                        input_tensor, fwd_wait_send_recv_handles = send_forward_recv_slave_forward(
                             output_tensor, tensor_shape, config, fwd_model_chunk_id, async_op=True)
-
-                if not forward_only and firstFB_no_overlp_handle is not None:
-                    for req, req_handle in firstFB_no_overlp_handle.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    firstFB_no_overlp_handle = None
+                    input_tensors[1 - fwd_model_chunk_id].append(input_tensor)

-                return input_tensor
+                return output_tensor

             def pp_pre_backward():
-                nonlocal bwd_wait_handles
+                nonlocal bwd_wait_send_recv_handles
                 if not forward_only:
-                    if bwd_wait_handles is not None:
-                        for _, req_handle in bwd_wait_handles.items():
-                            if req_handle is not None:
-                                req_handle.wait()
-                        bwd_wait_handles = None
+                    wait_comm_handles(bwd_wait_send_recv_handles)

             def pp_post_backward(input_tensor_grad):
-                nonlocal fwd_wait_handles
-                nonlocal bwd_wait_handles
-
-                if fwd_wait_handles is not None:
-                    for _, req_handle in fwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    fwd_wait_handles = None
+                nonlocal fwd_wait_send_handles
+                nonlocal fwd_wait_send_recv_handles
+                nonlocal bwd_wait_send_recv_handles
+
+                # Check whether the forward data transmission is completed.
+                wait_comm_handles(fwd_wait_send_handles[fwd_model_chunk_id])
+                wait_comm_handles(fwd_wait_send_recv_handles)

                 if not forward_only:
-                    deallocate_output_tensor(
-                        output_tensor, config.deallocate_pipeline_outputs)
+                    deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

                 if not forward_only:
                     if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
                         output_tensor_grad = input_tensor_grad
                     else:
-                        output_tensor_grad, bwd_wait_handles = send_backward_recv_slave_backward(
+                        output_tensor_grad, bwd_wait_send_recv_handles = send_backward_recv_slave_backward(
                             input_tensor_grad,
                             tensor_shape,
                             config,
                             fwd_model_chunk_id,
                             async_op=True
                         )
-                else:
-                    output_tensor_grad = None
+                    output_tensor_grads[fwd_model_chunk_id].append(output_tensor_grad)

-                return output_tensor_grad
+                return input_tensor_grad

             # forward
             pp_pre_forward()
             output_tensor = forward_step_helper(
-                input_tensor,
                 fwd_model_chunk_id,
                 cur_fwd_chunk_microbatch[fwd_model_chunk_id],
                 is_first_microbatch=False
             )
             cur_fwd_chunk_microbatch[fwd_model_chunk_id] += 1
-            input_tensor = pp_post_forward(output_tensor)
+            output_tensor = pp_post_forward(output_tensor)

             # backward
             pp_pre_backward()
             if not forward_only:
-                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
-                input_tensor_grad = backward_step_helper(input_tensor_bwd, output_tensor_bwd, output_tensor_grad_bwd)
+                try:
+                    input_tensor_grad = backward_step_helper(bwd_model_chunk_id)
+                except Exception as e:
+                    print(f"step_id: {step_id}, rank: {torch.distributed.get_rank()}, bwd_model_chunk_id: {bwd_model_chunk_id}", flush=True)
+                    raise Exception(f"{e}")
                 cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1
             else:
                 input_tensor_grad = None
-            output_tensor_grad_bwd = pp_post_backward(input_tensor_grad)
+            _ = pp_post_backward(input_tensor_grad)

         # only run backward
         else:
             if bwd_model_chunk_id == slave_chunk_id and cur_fwd_chunk_microbatch[slave_chunk_id] < num_chunk_max_microbatch[slave_chunk_id]:
-                input_tensor, fwd_wait_handles_recv = recv_forward(
-                    tensor_shape, config, slave_chunk_id, async_op=True)
+                input_tensor, fwd_wait_recv_handles[slave_chunk_id] = recv_forward(tensor_shape, config, slave_chunk_id, async_op=True)
+                input_tensors[slave_chunk_id].append(input_tensor)

             if not forward_only:
-                if bwd_wait_handles is not None:
-                    for req, req_handle in bwd_wait_handles.items():
-                        if req_handle is not None:
-                            req_handle.wait()
-                    bwd_wait_handles = None
-
-                input_tensor_bwd = input_tensors[bwd_model_chunk_id].pop(0)[1]
-                output_tensor_bwd = output_tensors[bwd_model_chunk_id].pop(0)
+                wait_comm_handles(bwd_wait_send_handles[1 - bwd_model_chunk_id])
+                wait_comm_handles(bwd_wait_send_recv_handles)

                 input_tensor_grad = backward_step_helper(
-                    input_tensor_bwd,
-                    output_tensor_bwd,
-                    output_tensor_grad_bwd,
-                    bwd_model_chunk_id=bwd_model_chunk_id,
+                    bwd_model_chunk_id,
                     bwd_cur_microbatch=cur_bwd_chunk_microbatch[bwd_model_chunk_id]
                 )
                 cur_bwd_chunk_microbatch[bwd_model_chunk_id] += 1

                 if parallel_state.is_pipeline_last_stage() and fwd_model_chunk_id == master_chunk_id:
-                    output_tensor_grad_bwd = input_tensor_grad
+                    output_tensor_grad = input_tensor_grad
                 else:
                     if step_id == num_overlap_steps - 1:
-                        bwd_wait_handles = send_backward(
+                        bwd_wait_send_handles[bwd_model_chunk_id] = send_backward(
                             input_tensor_grad,
                             tensor_shape,
                             config,
                             bwd_model_chunk_id,
                         )
+                        output_tensor_grad = None
                     else:
                         # send_backward_recv_slave_backward
-                        output_tensor_grad_bwd, bwd_wait_handles = send_backward_recv_slave_backward(
+                        output_tensor_grad, bwd_wait_send_recv_handles = send_backward_recv_slave_backward(
                             input_tensor_grad,
                             tensor_shape,
                             config,
                             fwd_model_chunk_id
                         )
+                output_tensor_grads[1 - bwd_model_chunk_id].append(output_tensor_grad)

         # swap fwd & bwd chunks
         fwd_model_chunk_id, bwd_model_chunk_id = bwd_model_chunk_id, fwd_model_chunk_id
@@ -1102,16 +1070,11 @@ def forward_backward_pipelining_with_cutinhalf(
     # Run cooldown phases
     if not forward_only:
         for i in range(schedule['cooldown'][rank][0]):
-            output_tensor_grad_bwd, _ = recv_backward(tensor_shape, config, master_chunk_id)
-
-            input_tensor_bwd = input_tensors[master_chunk_id].pop(0)[1]
-            output_tensor_bwd = output_tensors[master_chunk_id].pop(0)
+            output_tensor_grad, _ = recv_backward(tensor_shape, config, master_chunk_id)
+            output_tensor_grads[master_chunk_id].append(output_tensor_grad)

             input_tensor_grad = backward_step_helper(
-                input_tensor_bwd,
-                output_tensor_bwd,
-                output_tensor_grad_bwd,
-                bwd_model_chunk_id=master_chunk_id,
+                master_chunk_id,
                 bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
             )
             cur_bwd_chunk_microbatch[master_chunk_id] += 1
...
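The schedule changes above replace ad-hoc handle bookkeeping with per-chunk dicts of asynchronous send/recv requests that are drained by wait_comm_handles before buffers are reused or deallocated. The standalone sketch below illustrates that pattern with plain torch.distributed point-to-point ops; it assumes an already-initialized process group with an even number of ranks, and the helper names are illustrative rather than taken from this repository.

# Sketch of the handle-draining pattern, not the repository's send/recv helpers.
import torch
import torch.distributed as dist


def wait_comm_handles(comm_handles):
    # comm_handles maps an op name to an async work handle (or None);
    # wait on every outstanding request before touching the buffers.
    if comm_handles is None:
        return
    for _, req_handle in comm_handles.items():
        if req_handle is not None:
            req_handle.wait()


def async_send_recv(send_buf, recv_buf, peer):
    # Post a non-blocking send and recv and return the handles so the
    # caller can overlap compute with communication.
    return {
        'send': dist.isend(send_buf, dst=peer),
        'recv': dist.irecv(recv_buf, src=peer),
    }


if __name__ == '__main__' and dist.is_available() and dist.is_initialized():
    rank = dist.get_rank()
    peer = rank ^ 1  # pair ranks 0<->1, 2<->3, ...
    send_buf = torch.ones(4) * rank
    recv_buf = torch.empty(4)
    handles = async_send_recv(send_buf, recv_buf, peer)
    # ... overlapped forward/backward compute would run here ...
    wait_comm_handles(handles)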
 import gc
 import sys
+import time
 from functools import wraps

 import torch.distributed
@@ -14,7 +15,10 @@ from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
 from megatron.core.distributed import finalize_model_grads
-from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.rerun_state_machine import (
+    get_rerun_state_machine,
+    RerunMode,
+)
 from megatron.training.initialize import write_args_to_tensorboard
 from megatron.core.num_microbatches_calculator import (
     get_current_global_batch_size,
@@ -30,6 +34,8 @@ from megatron.training.utils import (
     logical_and_across_model_parallel_group,
     reduce_max_stat_across_model_parallel_group,
     unwrap_model,
+    is_rank0,
+    is_last_rank,
 )
 from megatron.training.global_vars import (
     get_args,
@@ -55,6 +61,7 @@ from megatron.training.training import (
     cuda_graph_capture,
     cuda_graph_set_manual_hooks,
     dummy_train_step,
+    _TRAIN_START_TIME,
 )
 from megatron.core.pipeline_parallel import get_forward_backward_func
@@ -560,3 +567,133 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             sys.exit(exit_code)

     return iteration, num_floating_point_operations_so_far
+
+
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             config,
+             verbose=False,
+             non_loss_data_func=None):
+    """Evaluation."""
+    args = get_args()
+    timers = get_timers()
+
+    timers('evaluate', log_level=0).start(barrier=True)
+
+    if args.vision_pretraining and args.vision_pretraining_type == "dino":
+        from megatron.legacy.model.vision.knn_monitor import compute_feature_bank
+        compute_feature_bank(model)
+
+    # Turn on evaluation mode which disables dropout.
+    for model_module in model:
+        model_module.eval()
+
+    # Disable result validation during evaluation
+    rerun_state_machine = get_rerun_state_machine()
+    rerun_mode = rerun_state_machine.get_mode()
+    rerun_state_machine.set_mode(RerunMode.DISABLED)
+
+    total_loss_dict = {}
+
+    # make validation batch size independent from training batch size
+    eval_batch_size = args.global_batch_size
+    eval_num_microbatches = eval_batch_size // \
+        (args.micro_batch_size * args.data_parallel_size)
+
+    with torch.no_grad():
+        iteration = 0
+        if verbose:
+            print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples')
+        while iteration < args.eval_iters:
+            iteration += 1
+            if verbose:
+                print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}')
+
+            forward_backward_func = get_forward_backward_func()
+            # Don't care about timing during evaluation
+            config.timers = None
+            ft_integration.on_eval_step_start()
+            loss_dicts = forward_backward_func(
+                forward_step_func=forward_step_func,
+                data_iterator=data_iterator,
+                model=model,
+                num_microbatches=eval_num_microbatches,
+                seq_length=args.seq_length,
+                micro_batch_size=args.micro_batch_size,
+                decoder_seq_length=args.decoder_seq_length,
+                forward_only=True)
+            ft_integration.on_eval_step_end()
+            config.timers = get_timers()
+
+            # Empty unused memory
+            if args.empty_unused_memory_level >= 1:
+                torch.cuda.empty_cache()
+
+            if args.schedule_method == 'dualpipev':
+                is_last_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
+            else:
+                is_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
+
+            if is_last_stage:
+                # Reduce across processes.
+                for loss_dict in loss_dicts:
+                    for key in loss_dict:
+                        if key not in total_loss_dict:
+                            total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda()
+                        val = loss_dict[key]
+                        if isinstance(val, tuple) or isinstance(val, list):
+                            total_loss_dict[key][0] += val[0]
+                            total_loss_dict[key][1] += val[1]
+                        else:
+                            total_loss_dict[key][0] += val
+                            total_loss_dict[key][1] += 1
+
+            args.consumed_valid_samples += eval_batch_size
+
+            if args.exit_duration_in_mins:
+                train_time = (time.time() - _TRAIN_START_TIME) / 60.0
+                done_cuda = torch.tensor(
+                    [train_time > args.exit_duration_in_mins],
+                    dtype=torch.int, device='cuda')
+                torch.distributed.all_reduce(
+                    done_cuda, op=torch.distributed.ReduceOp.MAX)
+                done = done_cuda.item()
+                if done:
+                    rerun_state_machine.set_mode(rerun_mode)
+                    print_rank_0('Exiting during evaluation, timelimit reached')
+                    return None, None, True
+
+        is_last_rank_func = is_rank0 if args.schedule_method == 'dualpipev' else is_last_rank
+        collected_non_loss_data = None
+        if non_loss_data_func is not None:
+            collected_non_loss_data = non_loss_data_func(model)
+        elif process_non_loss_data_func is not None and is_last_rank_func():
+            collected_non_loss_data = forward_backward_func(
+                forward_step_func=forward_step_func,
+                data_iterator=data_iterator,
+                model=model,
+                num_microbatches=get_num_microbatches(),
+                seq_length=args.seq_length,
+                micro_batch_size=args.micro_batch_size,
+                decoder_seq_length=args.decoder_seq_length,
+                forward_only=True,
+                collect_non_loss_data=True)
+
+    # Move model back to the train mode.
+    for model_module in model:
+        model_module.train()
+
+    for key in total_loss_dict:
+        numerator, denominator = total_loss_dict[key]
+        total_loss_dict[key] = numerator / denominator
+
+    timers('evaluate').stop()
+    timers.log(['evaluate'])
+
+    rerun_state_machine.set_mode(rerun_mode)
+
+    return total_loss_dict, collected_non_loss_data, False
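The evaluate override above differs from upstream Megatron mainly in where results are collected: with the dualpipev (cut-in-half) schedule the second model chunk is folded back onto the earlier pipeline ranks, so the loss is produced on pipeline stage 0 and reported from rank 0, whereas a conventional schedule produces it on the last stage and last rank. A minimal sketch of that stage selection, assuming Megatron's parallel_state helpers and the args.schedule_method flag used above:

# Sketch only: which pipeline stage holds the loss for a given schedule.
from megatron.core import parallel_state as mpu


def loss_stage_is_local(schedule_method: str) -> bool:
    # With dualpipev the last model chunk sits on pipeline stage 0, so the
    # loss is produced on the *first* stage; otherwise it is on the last stage.
    if schedule_method == 'dualpipev':
        return mpu.is_pipeline_first_stage(ignore_virtual=True)
    return mpu.is_pipeline_last_stage(ignore_virtual=True)


# Usage sketch inside an evaluation loop:
# if loss_stage_is_local(args.schedule_method):
#     ...reduce and accumulate loss_dicts on this rank...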