Commit e103a256 authored by dongcl

Patch for Megatron commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919

parent ade7b0dc
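The patched `pretrain_gpt.py` below imports the adaptor via `from dcu_megatron import megatron_adaptor`, so the patch relies on that import running before Megatron code is used. A minimal, hypothetical launcher sketch of that convention (nothing in this block is part of the commit; `dcu_megatron` is assumed to be installed):

```python
# Hypothetical usage sketch: importing the adaptor first lets its import-time side
# effects monkey-patch Megatron-Core symbols (TE linear wrappers, pipeline schedules)
# before any of them are instantiated.
from dcu_megatron import megatron_adaptor  # noqa: F401

# Megatron entry points are imported afterwards, so they resolve to the patched code.
from megatron.training import pretrain
```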
@@ -10,6 +10,7 @@ from packaging.version import Version as PkgVersion
from megatron.training import get_args
from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
@@ -20,7 +21,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
-get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
@@ -69,6 +69,8 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
+symmetric_ar_type: Optional[str] = None,
+tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -90,6 +92,8 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
is_expert=is_expert,
+symmetric_ar_type=symmetric_ar_type,
+tp_group=tp_group,
)
def backward_dw(self):
@@ -118,6 +122,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
+tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
@@ -139,6 +144,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
+tp_group=tp_group,
)
def backward_dw(self):
...
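The hunks above only widen the fork's `TELinear` and `TELayerNormColumnParallelLinear` wrappers so the newer Megatron-Core keyword arguments (`symmetric_ar_type`, `tp_group`) are accepted and forwarded to `super().__init__`. A small self-contained sketch of that pass-through pattern, using a stand-in base class rather than the real Megatron-Core signature:

```python
from typing import Optional


class UpstreamLinear:
    """Illustrative stand-in for the Megatron-Core TE linear wrapper."""

    def __init__(self, input_size: int, output_size: int,
                 symmetric_ar_type: Optional[str] = None, tp_group=None):
        self.symmetric_ar_type = symmetric_ar_type
        self.tp_group = tp_group


class PatchedLinear(UpstreamLinear):
    """Fork-side wrapper: accept the new keyword arguments and forward them unchanged."""

    def __init__(self, input_size: int, output_size: int,
                 symmetric_ar_type: Optional[str] = None, tp_group=None):
        # Fork-specific setup (e.g. reading a split_bw flag) would happen here,
        # then everything is passed straight through to the upstream constructor.
        super().__init__(input_size, output_size,
                         symmetric_ar_type=symmetric_ar_type, tp_group=tp_group)


layer = PatchedLinear(1024, 4096, symmetric_ar_type=None, tp_group=None)
print(layer.symmetric_ar_type, layer.tp_group)
```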
@@ -282,6 +282,7 @@ def forward_backward_step(
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
+vp_stage=None,
encoder_decoder_xattn=False,
):
"""Forward step for passed-in model.
@@ -345,6 +346,8 @@ def forward_backward_step(
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
+vp_stage (int, optional):
+The virtual pipeline stage. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
@@ -435,13 +438,19 @@ def forward_backward_step(
num_tokens = None
if f_model:
with f_context:
+model_vp_stage = getattr(f_model, "vp_stage", None)
+if vp_stage is not None and model_vp_stage is not None:
+assert (
+vp_stage == model_vp_stage
+), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
num_tokens = torch.tensor(0, dtype=torch.int)
args = get_args()
is_last_stage = False
if args.schedule_method == "dualpipev":
is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
else:
-is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
+is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)
if is_last_stage:
if not collect_non_loss_data:
loss_node = ScheduleNode(
...
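Beyond threading `vp_stage` through, the hunk above asserts that the caller-supplied `vp_stage` matches the model chunk's own `vp_stage` and then passes it to `is_pipeline_last_stage(ignore_virtual=False, vp_stage=...)`. A hedged, self-contained sketch of what a `vp_stage`-aware last-stage test means under interleaved pipelining; the function below is illustrative, the real logic lives in `megatron.core.parallel_state`:

```python
def is_vp_last_stage(pp_rank: int, pp_world_size: int,
                     vp_stage: int, num_model_chunks: int) -> bool:
    # Under interleaved (virtual) pipelining a rank only produces the final model
    # output while running its last model chunk on the last pipeline rank.
    return pp_rank == pp_world_size - 1 and vp_stage == num_model_chunks - 1


# Example: 4 pipeline ranks, 2 chunks per rank -> only (rank 3, chunk 1) computes the loss.
assert is_vp_last_stage(3, 4, 1, 2)
assert not is_vp_last_stage(3, 4, 0, 2)
assert not is_vp_last_stage(0, 4, 1, 2)
```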
@@ -1111,11 +1111,9 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad, _ = recv_backward(tensor_shape, config, master_chunk_id)
output_tensor_grads[master_chunk_id].append(output_tensor_grad)
-input_tensor_grad = backward_step_helper(
-master_chunk_id,
-bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
+_, input_tensor_grad = forward_backward_helper_wrapper(
+bwd_model_chunk_id=master_chunk_id,
)
-cur_bwd_chunk_microbatch[master_chunk_id] += 1
_ = send_backward(
input_tensor_grad,
...
import contextlib
-from typing import Iterator, List, Union
+from functools import partial
+from typing import Callable, Iterator, List, Optional, Union
import torch
@@ -12,6 +13,8 @@ from megatron.core.utils import (
get_model_config,
get_model_type,
get_model_xattn,
+nvtx_range_pop,
+nvtx_range_push,
)
from megatron.core.pipeline_parallel.schedules import (
forward_step,
@@ -82,10 +85,11 @@ def forward_backward_pipelining_with_interleaving(
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
-decoder_seq_length: int = None,
+decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
-first_val_step: bool = None,
+first_val_step: Optional[bool] = None,
+adjust_tensor_shapes_fn: Optional[Callable] = None,  # unused
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
@@ -106,6 +110,9 @@ def forward_backward_pipelining_with_interleaving(
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
+assert (
+adjust_tensor_shapes_fn is None
+), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"
config = get_model_config(model[0])
@@ -373,11 +380,8 @@ def forward_backward_pipelining_with_interleaving(
def forward_step_helper(
virtual_microbatch_id, microbatch_id, checkpoint_activations_microbatch
):
-"""Helper method to run forward step with model split into chunks
-(run set_virtual_pipeline_model_parallel_rank() before calling
-forward_step())."""
+"""Helper method to run forward step with model split into chunks"""
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
@@ -399,7 +403,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
-if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
@@ -427,6 +431,7 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk(virtual_microbatch_id),
),
current_microbatch=microbatch_id,
+vp_stage=model_chunk_id,
)
output_tensors[model_chunk_id].append(output_tensor)
@@ -443,13 +448,8 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor
def backward_step_helper(virtual_microbatch_id):
-"""Helper method to run backward step with model split into chunks
-(run set_virtual_pipeline_model_parallel_rank() before calling
-backward_step())."""
-nonlocal output_tensor_grads
+"""Helper method to run backward step with model split into chunks"""
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(
@@ -459,7 +459,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks.add(model_chunk_id)
# pylint: disable=E0606
-if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
@@ -509,7 +509,7 @@ def forward_backward_pipelining_with_interleaving(
if f_virtual_microbatch_id is not None:
model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
f_model_chunk_id = model_chunk_id
-f_context = VppContextManager(f_model_chunk_id)
+# f_context = VppContextManager(f_model_chunk_id)
with f_context:
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
@@ -533,7 +533,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
-if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
@@ -556,7 +556,7 @@ def forward_backward_pipelining_with_interleaving(
if b_virtual_microbatch_id is not None:
model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
b_model_chunk_id = model_chunk_id
-b_context = VppContextManager(b_model_chunk_id)
+# b_context = VppContextManager(b_model_chunk_id)
with b_context:
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(
@@ -565,7 +565,7 @@ def forward_backward_pipelining_with_interleaving(
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
-if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
b_input_tensor = input_tensors[model_chunk_id].pop(0)
@@ -602,6 +602,7 @@ def forward_backward_pipelining_with_interleaving(
),
),
current_microbatch=f_microbatch_id,
+vp_stage=f_model_chunk_id,
)
# forward post process
@@ -675,8 +676,6 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = None
if f_virtual_microbatch_id is not None:
# forward pass
-forward_model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if pre_forward is not None:
pre_forward()
@@ -689,8 +688,6 @@ def forward_backward_pipelining_with_interleaving(
if b_virtual_microbatch_id is not None:
# Backward pass.
-backward_model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if pre_backward is not None:
pre_backward()
input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
@@ -698,9 +695,15 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
+is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
+is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
# Run warmup forward passes.
+nvtx_range_push(suffix="warmup")
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
+input_tensors[0].append(
+p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage(vp_stage=0))
+)
fwd_wait_handles = None
fwd_wait_recv_handles = None
@@ -727,10 +730,9 @@ def forward_backward_pipelining_with_interleaving(
for k in range(num_warmup_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm_warmup_flush:
-if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
+if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, iteration {k},'
'should have registered recv handle'
@@ -777,7 +779,7 @@ def forward_backward_pipelining_with_interleaving(
)
# Don't send tensor downstream if on last stage.
-if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if is_vp_last_stage(vp_stage=cur_model_chunk_id):
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
@@ -880,8 +882,10 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
+nvtx_range_pop(suffix="warmup")
# Run 1F1B in steady state.
+nvtx_range_push(suffix="steady")
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
@@ -895,14 +899,15 @@ def forward_backward_pipelining_with_interleaving(
else:
checkpoint_activations_microbatch = None
+microbatch_id = get_microbatch_id_in_model_chunk(forward_k, forward=True)
if config.overlap_p2p_comm:
+# output send / receive sync
def pp_pre_forward():
nonlocal recv_prev_wait_handles
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
-if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
@@ -917,7 +922,6 @@ def forward_backward_pipelining_with_interleaving(
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
-# output async send / receive
def pp_post_forward(output_tensor):
nonlocal send_next_wait_handle
nonlocal fwd_recv_buffer
@@ -927,10 +931,9 @@ def forward_backward_pipelining_with_interleaving(
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send.
-if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -963,8 +966,6 @@ def forward_backward_pipelining_with_interleaving(
recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
# assert fwd_wait_handles is not None
-# Put input_tensor and output_tensor_grad in data structures in the
-# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(
fwd_recv_buffer[forward_k % fwd_recv_buffer_size]
@@ -973,14 +974,13 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor
-# Backward pass.
backward_k = k
# grad send receive sync
def pp_pre_backward():
nonlocal recv_next_wait_handles
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
@@ -1000,11 +1000,9 @@ def forward_backward_pipelining_with_interleaving(
nonlocal recv_next_wait_handles
nonlocal bwd_recv_buffer
-backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send.
-if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
+if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -1036,6 +1034,7 @@ def forward_backward_pipelining_with_interleaving(
bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
)
bwd_recv_buffer[(backward_k + 1) % bwd_recv_buffer_size] = None
return input_tensor_grad
output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
@@ -1061,13 +1060,11 @@ def forward_backward_pipelining_with_interleaving(
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
-parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
-if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
+if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
-if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
@@ -1104,8 +1101,11 @@ def forward_backward_pipelining_with_interleaving(
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
+nvtx_range_pop(suffix="steady")
-# Run cooldown backward passes (flush out pipeline).
+# Run cooldown backward passes (flush out pipeline) for the last model chunk.
+nvtx_range_push(suffix="cooldown")
+curr_vp_stage = config.virtual_pipeline_model_parallel_size - 1
if not forward_only:
if bwd_wait_handles is not None:
for bwd_wait_handle in bwd_wait_handles.values():
@@ -1113,12 +1113,15 @@ def forward_backward_pipelining_with_interleaving(
if are_all_microbatches_in_warmup:
output_tensor_grads[num_model_chunks - 1].append(
-p2p_communication.recv_backward(tensor_shape, config=config)
+p2p_communication.recv_backward(
+tensor_shape,
+config=config,
+is_last_stage=is_vp_last_stage(vp_stage=curr_vp_stage),
+)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=False)
-parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
-if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
+if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
@@ -1158,7 +1161,7 @@ def forward_backward_pipelining_with_interleaving(
_, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)
# First virtual stage no activation gradient tensor to send.
-if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
+if is_vp_first_stage(vp_stage=cur_model_chunk_id):
input_tensor_grad = None
if config.overlap_p2p_comm_warmup_flush:
@@ -1215,7 +1218,9 @@ def forward_backward_pipelining_with_interleaving(
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
+nvtx_range_pop(suffix="cooldown")
+nvtx_range_push(suffix="misc")
assert (
not recv_prev_wait_handles
), 'recv_prev_wait_handles should be cleared at the end of a step'
@@ -1245,5 +1250,6 @@ def forward_backward_pipelining_with_interleaving(
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
+nvtx_range_pop(suffix="misc")
return forward_data_store
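The recurring change in this file is dropping `parallel_state.set_virtual_pipeline_model_parallel_rank(...)` in favour of stage predicates bound once with `functools.partial` and then called with the model chunk id. A runnable sketch of that pattern with an illustrative stand-in predicate (the real one lives in `megatron.core.parallel_state` and reads the rank from process-group state):

```python
from functools import partial

PIPELINE_RANK = 0          # illustrative constants; real code queries torch.distributed
PIPELINE_WORLD_SIZE = 4


def is_pipeline_first_stage(ignore_virtual=True, vp_stage=None):
    if not ignore_virtual and vp_stage is not None:
        # Under interleaving, only model chunk 0 on the first pipeline rank is "first".
        return PIPELINE_RANK == 0 and vp_stage == 0
    return PIPELINE_RANK == 0


# Bound once per schedule invocation, then queried with an explicit chunk id instead of
# mutating a process-global "current virtual pipeline rank" before every check.
is_vp_first_stage = partial(is_pipeline_first_stage, ignore_virtual=False)

for model_chunk_id in range(2):
    print(model_chunk_id, is_vp_first_stage(vp_stage=model_chunk_id))
```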
@@ -87,6 +87,12 @@ def parallel_attention_init_wrapper(fn):
return wrapper
+class ParallelAttentionPatch(MegatronModule):
+"""Parallel self-attention layer abstract class.
+Self-attention layer takes input with size [s, b, h]
+and returns output of the same size.
+"""
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_context=None,
rotary_pos_emb=None, *, inference_params=None):
...
This diff is collapsed.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
+import datetime
import os
import torch
-from functools import partial
-from contextlib import nullcontext
-import inspect
+from functools import partial
from typing import List, Optional, Tuple, Union
+from megatron.core import parallel_state
from megatron.training import get_args
+from megatron.training import inprocess_restart
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
-from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
-from megatron.core.rerun_state_machine import get_rerun_state_machine
-import megatron.legacy.model
+from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
+from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
-from megatron.training import pretrain
-from megatron.core.utils import StragglerDetector
+from megatron.core.models.gpt.gpt_layer_specs import (
+get_gpt_decoder_block_spec,
+get_gpt_layer_local_spec,
+get_gpt_layer_with_transformer_engine_spec,
+get_gpt_mtp_block_spec,
+)
+from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
+get_gpt_heterogeneous_layer_spec,
+)
+from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer.spec_utils import import_module
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
+from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
get_blend_and_blend_per_split,
)
-from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
-from megatron.core.models.gpt.gpt_layer_specs import (
-get_gpt_decoder_block_spec,
-get_gpt_layer_local_spec,
-get_gpt_layer_with_transformer_engine_spec,
-get_gpt_mtp_block_spec,
-)
-from dcu_megatron import megatron_adaptor
+import megatron.legacy.model  # isort: skip
+# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
+try:
+from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
+from megatron.post_training.loss_func import loss_func as loss_func_modelopt
+from megatron.post_training.model_provider import model_provider as model_provider_modelopt
+has_nvidia_modelopt = True
+except ImportError:
+has_nvidia_modelopt = False
+from dcu_megatron import megatron_adaptor
stimer = StragglerDetector()
-def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
+def model_provider(
+pre_process=True, post_process=True, vp_stage: Optional[int] = None
+) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.
If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
@@ -55,24 +75,33 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
+if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+return model_provider_modelopt(pre_process, post_process)
if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
assert args.transformer_impl == "transformer_engine"
use_te = args.transformer_impl == "transformer_engine"
if args.record_memory_history:
-torch.cuda.memory._record_memory_history(True,
+torch.cuda.memory._record_memory_history(
+True,
# keep 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
# record stack information for the trace events
-trace_alloc_record_context=True)
+trace_alloc_record_context=True,
+)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
-dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
+dump(
+snapshot,
+open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
+)
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
@@ -97,21 +126,36 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
else:
if args.num_experts:
# Define the decoder block spec
-transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
+transformer_layer_spec = get_gpt_decoder_block_spec(
+config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
+)
+elif args.heterogeneous_layers_config_path is not None:
+transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
else:
# Define the decoder layer spec
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-args.num_experts, args.moe_grouped_gemm,
-args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
+args.num_experts,
+args.moe_grouped_gemm,
+args.qk_layernorm,
+args.multi_latent_attention,
+args.moe_use_legacy_grouped_gemm,
+qk_l2_norm=args.qk_l2_norm
+)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
-args.num_experts, args.moe_grouped_gemm,
-args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
-normalization=args.normalization)
+args.num_experts,
+args.moe_grouped_gemm,
+args.qk_layernorm,
+args.multi_latent_attention,
+args.moe_use_legacy_grouped_gemm,
+normalization=args.normalization,
+)
mtp_block_spec = None
if args.mtp_num_layers is not None:
-mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
+mtp_block_spec = get_gpt_mtp_block_spec(
+config, transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage
+)
model = GPTModel(
config=config,
@@ -128,16 +172,19 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
+vp_stage=vp_stage,
)
+print_rank_0(model)
return model
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
-if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
+not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+):
return None, None, None, None, None
# get batches based on the TP rank you are on
@@ -153,12 +200,15 @@ def get_batch(data_iterator):
SPIKY_LOSS_FACTOR = 10
-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+def loss_func(
+loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
+):
"""Loss function.
Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses
+model (GPTModel, optional): The model (can be wrapped)
Returns:
the loss scalar for this micro-batch
@@ -168,26 +218,25 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""
args = get_args()
-losses = output_tensor.float()
-loss_mask = loss_mask.view(-1).float()
-total_tokens = loss_mask.sum()
-loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
-if args.context_parallel_size > 1:
-torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+return loss_func_modelopt(loss_mask, output_tensor, model=model)
+losses = output_tensor.view(-1).float()
+loss_mask = loss_mask.view(-1).float()
+loss = torch.sum(losses * loss_mask)
# Check individual rank losses are not NaN prior to DP all-reduce.
rerun_state_machine = get_rerun_state_machine()
if args.check_for_nan_in_loss_and_grad:
rerun_state_machine.validate_result(
-result=loss[0],
+result=loss,
rejection_func=torch.isnan,
message="found NaN in local forward loss calculation",
tolerance=0.0,  # forward pass calculations are determinisic
fatal=True,
)
rerun_state_machine.validate_result(
-result=loss[0],
+result=loss,
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0,  # forward pass calculations are determinisic
@@ -196,7 +245,7 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
-result=loss[0],
+result=loss,
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
@@ -206,19 +255,11 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
tolerance=0.0,  # forward pass calculations are determinisic
fatal=False,
)
-# Reduce loss for logging.
-reporting_loss = loss.clone().detach()
-torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
-# loss[0] is a view of loss, so it has ._base not None, which triggers assert error
-# in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
-# on loss[0] fixes this
-local_num_tokens = loss[1].clone().detach().to(torch.int)
-return (
-loss[0].clone(),
-local_num_tokens,
-{'lm loss': (reporting_loss[0], reporting_loss[1])},
-)
+num_tokens = loss_mask.sum().clone().detach().to(torch.int)
+reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
+return (loss, num_tokens, {'lm loss': reporting_loss})
def forward_step(data_iterator, model: GPTModel):
@@ -235,25 +276,26 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
-tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
-data_iterator)
+tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch-generator').stop()
with stimer:
if args.use_legacy_models:
-output_tensor = model(tokens, position_ids, attention_mask,
-labels=labels)
+output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
else:
-output_tensor = model(tokens, position_ids, attention_mask,
-labels=labels, loss_mask=loss_mask)
+output_tensor = model(
+tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+)
-return output_tensor, partial(loss_func, loss_mask)
+# [ModelOpt]: model is needed to access ModelOpt distillation losses
+return output_tensor, partial(loss_func, loss_mask, model=model)
def is_dataset_built_on_rank():
return (
-mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
-) and mpu.get_tensor_model_parallel_rank() == 0
+parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+) and parallel_state.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
@@ -278,7 +320,8 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
-s3_cache_path=args.s3_cache_path,
+object_storage_cache_path=args.object_storage_cache_path,
+mid_level_dataset_surplus=args.mid_level_dataset_surplus,
)
@@ -300,10 +343,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
-dataset_type,
-train_val_test_num_samples,
-is_dataset_built_on_rank,
-config
+dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
).build()
print_rank_0("> finished creating GPT datasets ...")
@@ -316,10 +356,15 @@ if __name__ == "__main__":
# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True
+# Optionally enable inprocess restart on pretrain
+pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
+store=store,
)
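For reference, the reworked `loss_func` above now returns the unreduced masked loss sum, the local token count, and a two-element reporting tensor, instead of a pre-averaged value. A self-contained sketch of that contract and of how a caller might average it; the aggregation loop is an assumption for illustration, not code from this patch:

```python
import torch


def toy_loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    # Mirrors the patched contract: sum the masked per-token losses and report the
    # token count separately, so averaging can happen after reduction.
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)
    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
    return loss, num_tokens, {'lm loss': reporting_loss}


# Illustrative consumer: accumulate (loss_sum, token_count) over microbatches and divide
# once at the end, which is why no per-microbatch averaging happens inside loss_func.
totals = torch.zeros(2)
for _ in range(4):
    fake_token_losses = torch.rand(2, 8)
    mask = torch.ones(2, 8)
    _, _, logged = toy_loss_func(mask, fake_token_losses)
    totals += logged['lm loss']
print("mean lm loss:", (totals[0] / totals[1]).item())
```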