"tools/imglab/src/image_dataset_metadata.cpp" did not exist on "fbce4cd61db822929bd5229cf260cb8e0b06aecf"
Commit 08efd4ec authored by dongcl's avatar dongcl
Browse files

rewrite schedules based on megatron a73b4d2d4a993e9bea97fdebb841a393eb4ad5e7

parent c964fcca
...@@ -13,6 +13,8 @@ from megatron.core.utils import ( ...@@ -13,6 +13,8 @@ from megatron.core.utils import (
get_model_config, get_model_config,
get_model_type, get_model_type,
get_model_xattn, get_model_xattn,
nvtx_range_pop,
nvtx_range_push,
) )
from megatron.core.pipeline_parallel.schedules import ( from megatron.core.pipeline_parallel.schedules import (
forward_step, forward_step,
...@@ -430,7 +432,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -430,7 +432,7 @@ def forward_backward_pipelining_with_interleaving(
) )
# forward step # forward step
if parallel_state.is_pipeline_first_stage(ignore_virtual=False): if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]): if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None) input_tensors[model_chunk_id].append(None)
...@@ -458,6 +460,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -458,6 +460,7 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk(virtual_microbatch_id), is_first_microbatch_for_model_chunk(virtual_microbatch_id),
), ),
current_microbatch=microbatch_id, current_microbatch=microbatch_id,
vp_stage=model_chunk_id,
) )
output_tensors[model_chunk_id].append(output_tensor) output_tensors[model_chunk_id].append(output_tensor)
...@@ -477,7 +480,6 @@ def forward_backward_pipelining_with_interleaving( ...@@ -477,7 +480,6 @@ def forward_backward_pipelining_with_interleaving(
"""Helper method to run backward step with model split into chunks """Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling (run set_virtual_pipeline_model_parallel_rank() before calling
backward_step()).""" backward_step())."""
nonlocal output_tensor_grads # TODO(dongcl)
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False) model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
...@@ -489,7 +491,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -489,7 +491,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks.add(model_chunk_id) synchronized_model_chunks.add(model_chunk_id)
# pylint: disable=E0606 # pylint: disable=E0606
if parallel_state.is_pipeline_last_stage(ignore_virtual=False): if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0: if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None) output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0) input_tensor = input_tensors[model_chunk_id].pop(0)
...@@ -728,9 +730,14 @@ def forward_backward_pipelining_with_interleaving( ...@@ -728,9 +730,14 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = post_backward(input_tensor_grad) input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad return output_tensor, input_tensor_grad
is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
# Run warmup forward passes. # Run warmup forward passes.
nvtx_range_push(suffix="warmup")
parallel_state.set_virtual_pipeline_model_parallel_rank(0) parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config)) input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage())
)
fwd_wait_handles = None fwd_wait_handles = None
fwd_wait_recv_handles = None fwd_wait_recv_handles = None
...@@ -760,7 +767,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -760,7 +767,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm_warmup_flush: if config.overlap_p2p_comm_warmup_flush:
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0: if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
assert recv_prev_wait_handles, ( assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, iteration {k},' f'pp rank {pipeline_parallel_rank}, iteration {k},'
'should have registered recv handle' 'should have registered recv handle'
...@@ -807,7 +814,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -807,7 +814,7 @@ def forward_backward_pipelining_with_interleaving(
) )
# Don't send tensor downstream if on last stage. # Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False): if is_vp_last_stage(vp_stage=cur_model_chunk_id):
output_tensor = None output_tensor = None
# Send and receive tensors as appropriate (send tensors computed # Send and receive tensors as appropriate (send tensors computed
...@@ -910,8 +917,10 @@ def forward_backward_pipelining_with_interleaving( ...@@ -910,8 +917,10 @@ def forward_backward_pipelining_with_interleaving(
if recv_next: if recv_next:
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1]) output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
nvtx_range_pop(suffix="warmup")
# Run 1F1B in steady state. # Run 1F1B in steady state.
nvtx_range_push(suffix="steady")
for k in range(num_microbatches_remaining): for k in range(num_microbatches_remaining):
# Forward pass. # Forward pass.
forward_k = k + num_warmup_microbatches forward_k = k + num_warmup_microbatches
...@@ -928,14 +937,14 @@ def forward_backward_pipelining_with_interleaving( ...@@ -928,14 +937,14 @@ def forward_backward_pipelining_with_interleaving(
else: else:
checkpoint_activations_microbatch = None checkpoint_activations_microbatch = None
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm: if config.overlap_p2p_comm:
backward_k = k
# output send / receive sync # output send / receive sync
def pp_pre_forward(): def pp_pre_forward():
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False): nonlocal recv_prev_wait_handles
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush: if config.overlap_p2p_comm_warmup_flush:
assert recv_prev_wait_handles, ( assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, ' f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
...@@ -956,8 +965,14 @@ def forward_backward_pipelining_with_interleaving( ...@@ -956,8 +965,14 @@ def forward_backward_pipelining_with_interleaving(
nonlocal fwd_recv_buffer nonlocal fwd_recv_buffer
nonlocal fwd_wait_handles nonlocal fwd_wait_handles
nonlocal recv_prev_wait_handles nonlocal recv_prev_wait_handles
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send. # Last virtual stage no activation tensor to send.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False): if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None output_tensor = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage( recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
...@@ -1002,10 +1017,14 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1002,10 +1017,14 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor return output_tensor
backward_k = k
# grad send receive sync # grad send receive sync
def pp_pre_backward(): def pp_pre_backward():
nonlocal recv_next_wait_handles nonlocal recv_next_wait_handles
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush: if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, ( assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, ' f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
...@@ -1023,8 +1042,13 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1023,8 +1042,13 @@ def forward_backward_pipelining_with_interleaving(
nonlocal send_prev_wait_handle nonlocal send_prev_wait_handle
nonlocal bwd_wait_handles nonlocal bwd_wait_handles
nonlocal recv_next_wait_handles nonlocal recv_next_wait_handles
nonlocal bwd_recv_buffer
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send. # First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False): if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None input_tensor_grad = None
recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage( recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
...@@ -1044,16 +1068,13 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1044,16 +1068,13 @@ def forward_backward_pipelining_with_interleaving(
send_prev_wait_handle.wait() send_prev_wait_handle.wait()
if bwd_wait_handles is not None: if bwd_wait_handles is not None:
send_prev_wait_handle = ( send_prev_wait_handle = (
bwd_wait_handles.pop("send_prev") bwd_wait_handles.pop("send_prev") if "send_prev" in bwd_wait_handles else None
if "send_prev" in bwd_wait_handles
else None
) )
if "recv_next" in bwd_wait_handles: if "recv_next" in bwd_wait_handles:
recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next")) recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next"))
# Put input_tensor and output_tensor_grad in data structures in the # Put input_tensor and output_tensor_grad in data structures in the
# right location. # right location.
if recv_next: if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grads[next_backward_model_chunk_id].append(
bwd_recv_buffer[backward_k % bwd_recv_buffer_size] bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
...@@ -1088,12 +1109,12 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1088,12 +1109,12 @@ def forward_backward_pipelining_with_interleaving(
# otherwise set tensor to None. # otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage(ignore_virtual=False): if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage(ignore_virtual=False): if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None input_tensor_grad = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage( recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
...@@ -1135,8 +1156,10 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1135,8 +1156,10 @@ def forward_backward_pipelining_with_interleaving(
print_rank_0(f"rank first. 1F1B in steady state end") print_rank_0(f"rank first. 1F1B in steady state end")
print_rank_4(f"rank last. 1F1B in steady state end") print_rank_4(f"rank last. 1F1B in steady state end")
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
nvtx_range_pop(suffix="steady")
# Run cooldown backward passes (flush out pipeline). # Run cooldown backward passes (flush out pipeline).
nvtx_range_push(suffix="cooldown")
if not forward_only: if not forward_only:
if bwd_wait_handles is not None: if bwd_wait_handles is not None:
for bwd_wait_handle in bwd_wait_handles.values(): for bwd_wait_handle in bwd_wait_handles.values():
...@@ -1144,12 +1167,14 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1144,12 +1167,14 @@ def forward_backward_pipelining_with_interleaving(
if are_all_microbatches_in_warmup: if are_all_microbatches_in_warmup:
output_tensor_grads[num_model_chunks - 1].append( output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config) p2p_communication.recv_backward(
tensor_shape, config=config, is_last_stage=is_vp_last_stage()
)
) )
for k in range(num_microbatches_remaining, total_num_microbatches): for k in range(num_microbatches_remaining, total_num_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=False) cur_model_chunk_id = get_model_chunk_id(k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id) parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0: if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
if config.overlap_p2p_comm_warmup_flush: if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, ( assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, backward iteration {k}, ' f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
...@@ -1189,7 +1214,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1189,7 +1214,7 @@ def forward_backward_pipelining_with_interleaving(
_, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k) _, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)
# First virtual stage no activation gradient tensor to send. # First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False): if is_vp_first_stage(vp_stage=cur_model_chunk_id):
input_tensor_grad = None input_tensor_grad = None
if config.overlap_p2p_comm_warmup_flush: if config.overlap_p2p_comm_warmup_flush:
...@@ -1246,7 +1271,9 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1246,7 +1271,9 @@ def forward_backward_pipelining_with_interleaving(
if model_chunk_id not in synchronized_model_chunks: if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters()) config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id) synchronized_model_chunks.add(model_chunk_id)
nvtx_range_pop(suffix="cooldown")
nvtx_range_push(suffix="misc")
assert ( assert (
not recv_prev_wait_handles not recv_prev_wait_handles
), 'recv_prev_wait_handles should be cleared at the end of a step' ), 'recv_prev_wait_handles should be cleared at the end of a step'
...@@ -1276,6 +1303,7 @@ def forward_backward_pipelining_with_interleaving( ...@@ -1276,6 +1303,7 @@ def forward_backward_pipelining_with_interleaving(
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph: if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs() create_cudagraphs()
nvtx_range_pop(suffix="misc")
return forward_data_store return forward_data_store
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment