Commit 08efd4ec authored by dongcl
Browse files

Rewrite schedules based on Megatron commit a73b4d2d4a993e9bea97fdebb841a393eb4ad5e7

parent c964fcca
......@@ -13,6 +13,8 @@ from megatron.core.utils import (
get_model_config,
get_model_type,
get_model_xattn,
nvtx_range_pop,
nvtx_range_push,
)
from megatron.core.pipeline_parallel.schedules import (
forward_step,
......@@ -430,7 +432,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -458,6 +460,7 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk(virtual_microbatch_id),
),
current_microbatch=microbatch_id,
vp_stage=model_chunk_id,
)
output_tensors[model_chunk_id].append(output_tensor)
......@@ -477,7 +480,6 @@ def forward_backward_pipelining_with_interleaving(
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
nonlocal output_tensor_grads # TODO(dongcl)
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
......@@ -489,7 +491,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks.add(model_chunk_id)
# pylint: disable=E0606
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
......@@ -728,9 +730,14 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
# Run warmup forward passes.
nvtx_range_push(suffix="warmup")
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage())
)
fwd_wait_handles = None
fwd_wait_recv_handles = None
......@@ -760,7 +767,7 @@ def forward_backward_pipelining_with_interleaving(
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm_warmup_flush:
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, iteration {k},'
'should have registered recv handle'
......@@ -807,7 +814,7 @@ def forward_backward_pipelining_with_interleaving(
)
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=cur_model_chunk_id):
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
......@@ -910,8 +917,10 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
nvtx_range_pop(suffix="warmup")
# Run 1F1B in steady state.
nvtx_range_push(suffix="steady")
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
......@@ -928,14 +937,14 @@ def forward_backward_pipelining_with_interleaving(
else:
checkpoint_activations_microbatch = None
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm:
backward_k = k
# output send / receive sync
def pp_pre_forward():
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
nonlocal recv_prev_wait_handles
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
......@@ -956,8 +965,14 @@ def forward_backward_pipelining_with_interleaving(
nonlocal fwd_recv_buffer
nonlocal fwd_wait_handles
nonlocal recv_prev_wait_handles
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1002,10 +1017,14 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor
backward_k = k
# grad send receive sync
def pp_pre_backward():
nonlocal recv_next_wait_handles
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
......@@ -1023,8 +1042,13 @@ def forward_backward_pipelining_with_interleaving(
nonlocal send_prev_wait_handle
nonlocal bwd_wait_handles
nonlocal recv_next_wait_handles
nonlocal bwd_recv_buffer
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1044,16 +1068,13 @@ def forward_backward_pipelining_with_interleaving(
send_prev_wait_handle.wait()
if bwd_wait_handles is not None:
send_prev_wait_handle = (
bwd_wait_handles.pop("send_prev")
if "send_prev" in bwd_wait_handles
else None
bwd_wait_handles.pop("send_prev") if "send_prev" in bwd_wait_handles else None
)
if "recv_next" in bwd_wait_handles:
recv_next_wait_handles.append(bwd_wait_handles.pop("recv_next"))
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(
bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
......@@ -1088,12 +1109,12 @@ def forward_backward_pipelining_with_interleaving(
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1135,8 +1156,10 @@ def forward_backward_pipelining_with_interleaving(
print_rank_0(f"rank first. 1F1B in steady state end")
print_rank_4(f"rank last. 1F1B in steady state end")
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
nvtx_range_pop(suffix="steady")
# Run cooldown backward passes (flush out pipeline).
nvtx_range_push(suffix="cooldown")
if not forward_only:
if bwd_wait_handles is not None:
for bwd_wait_handle in bwd_wait_handles.values():
......@@ -1144,12 +1167,14 @@ def forward_backward_pipelining_with_interleaving(
if are_all_microbatches_in_warmup:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
p2p_communication.recv_backward(
tensor_shape, config=config, is_last_stage=is_vp_last_stage()
)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
......@@ -1189,7 +1214,7 @@ def forward_backward_pipelining_with_interleaving(
_, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)
# First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if is_vp_first_stage(vp_stage=cur_model_chunk_id):
input_tensor_grad = None
if config.overlap_p2p_comm_warmup_flush:
......@@ -1246,7 +1271,9 @@ def forward_backward_pipelining_with_interleaving(
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
nvtx_range_pop(suffix="cooldown")
nvtx_range_push(suffix="misc")
assert (
not recv_prev_wait_handles
), 'recv_prev_wait_handles should be cleared at the end of a step'
......@@ -1276,6 +1303,7 @@ def forward_backward_pipelining_with_interleaving(
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
nvtx_range_pop(suffix="misc")
return forward_data_store
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment