Commit e103a256 authored by dongcl

Patch for Megatron-LM commit 0595ef2b0c93f8d61f473c9f99f9ff73803ff919

parent ade7b0dc
......@@ -10,6 +10,7 @@ from packaging.version import Version as PkgVersion
from megatron.training import get_args
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ModelCommProcessGroups
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.extensions.transformer_engine import TEDotProductAttention
......@@ -20,7 +21,6 @@ from megatron.core.extensions.transformer_engine import TELinear as MegatronCoreTELinear
from megatron.core.extensions.transformer_engine import TELayerNormColumnParallelLinear as MegatronCoreTELayerNormColumnParallelLinear
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
......@@ -69,6 +69,8 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
symmetric_ar_type: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
......@@ -90,6 +92,8 @@ class TELinear(MegatronCoreTELinear):
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
is_expert=is_expert,
symmetric_ar_type=symmetric_ar_type,
tp_group=tp_group,
)
def backward_dw(self):
......@@ -118,6 +122,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
args = get_args()
self.split_bw = args.split_bw if hasattr(args, "split_bw") else False
......@@ -139,6 +144,7 @@ class TELayerNormColumnParallelLinear(MegatronCoreTELayerNormColumnParallelLinear):
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
tp_group=tp_group,
)
def backward_dw(self):
......
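The wrappers above exist only to add the deferred weight-gradient hook (backward_dw / split_bw); the rest of the hunk is plumbing the keyword arguments that the targeted Megatron commit added upstream (symmetric_ar_type, tp_group) through to the Megatron-core TE classes. A minimal sketch of that pattern, with stand-in class names rather than the real Megatron-core imports:

from typing import Optional

import torch


class UpstreamTELinear:
    # Stand-in for megatron.core.extensions.transformer_engine.TELinear, which
    # gained symmetric_ar_type and tp_group in the targeted Megatron commit.
    def __init__(
        self,
        input_size: int,
        output_size: int,
        symmetric_ar_type: Optional[str] = None,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        self.symmetric_ar_type = symmetric_ar_type
        self.tp_group = tp_group


class PatchedTELinear(UpstreamTELinear):
    # Keep the local extension (split backward pass for dW) and forward every
    # new upstream keyword argument unchanged, so the wrapper keeps tracking
    # upstream signature changes.
    def __init__(self, *args, split_bw: bool = False, **kwargs):
        self.split_bw = split_bw
        super().__init__(*args, **kwargs)

    def backward_dw(self):
        # Placeholder; the real deferred weight-gradient step lives in the
        # patched file above.
        pass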
......@@ -282,6 +282,7 @@ def forward_backward_step(
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
vp_stage=None,
encoder_decoder_xattn=False,
):
"""Forward step for passed-in model.
......@@ -345,6 +346,8 @@ def forward_backward_step(
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
vp_stage (int, optional):
The virtual pipeline stage. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
......@@ -435,13 +438,19 @@ def forward_backward_step(
num_tokens = None
if f_model:
with f_context:
model_vp_stage = getattr(f_model, "vp_stage", None)
if vp_stage is not None and model_vp_stage is not None:
assert (
vp_stage == model_vp_stage
), f"vp_stage ({vp_stage}) doesn't match model_vp_stage ({model_vp_stage})"
num_tokens = torch.tensor(0, dtype=torch.int)
args = get_args()
is_last_stage = False
if args.schedule_method == "dualpipev":
is_last_stage = parallel_state.is_pipeline_first_stage() and get_dualpipe_chunk() == 1
else:
is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False)
is_last_stage = parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)
if is_last_stage:
if not collect_non_loss_data:
loss_node = ScheduleNode(
......
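This hunk adds a vp_stage consistency assertion and threads vp_stage into the last-stage test. Condensed into one helper (a sketch; get_dualpipe_chunk comes from this repository's dualpipev schedule, so its value is passed in rather than imported):

from megatron.core import parallel_state


def is_loss_stage(args, vp_stage=None, dualpipe_chunk=None):
    # With the dualpipev schedule, the loss-bearing chunk is the second chunk
    # on the pipeline-first rank; otherwise use the vp_stage-aware last-stage
    # check (signature as used in the diff above).
    if args.schedule_method == "dualpipev":
        return parallel_state.is_pipeline_first_stage() and dualpipe_chunk == 1
    return parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage)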
......@@ -1111,11 +1111,9 @@ def forward_backward_pipelining_with_cutinhalf(
output_tensor_grad, _ = recv_backward(tensor_shape, config, master_chunk_id)
output_tensor_grads[master_chunk_id].append(output_tensor_grad)
input_tensor_grad = backward_step_helper(
master_chunk_id,
bwd_cur_microbatch=cur_bwd_chunk_microbatch[master_chunk_id]
_, input_tensor_grad = forward_backward_helper_wrapper(
bwd_model_chunk_id=master_chunk_id,
)
cur_bwd_chunk_microbatch[master_chunk_id] += 1
_ = send_backward(
input_tensor_grad,
......
import contextlib
from typing import Iterator, List, Union
from functools import partial
from typing import Callable, Iterator, List, Optional, Union
import torch
......@@ -12,6 +13,8 @@ from megatron.core.utils import (
get_model_config,
get_model_type,
get_model_xattn,
nvtx_range_pop,
nvtx_range_push,
)
from megatron.core.pipeline_parallel.schedules import (
forward_step,
......@@ -82,10 +85,11 @@ def forward_backward_pipelining_with_interleaving(
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
decoder_seq_length: Optional[int] = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
first_val_step: Optional[bool] = None,
adjust_tensor_shapes_fn: Optional[Callable] = None, # unused
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
......@@ -106,6 +110,9 @@ def forward_backward_pipelining_with_interleaving(
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
assert (
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"
config = get_model_config(model[0])
......@@ -373,11 +380,8 @@ def forward_backward_pipelining_with_interleaving(
def forward_step_helper(
virtual_microbatch_id, microbatch_id, checkpoint_activations_microbatch
):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
"""Helper method to run forward step with model split into chunks"""
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
......@@ -399,7 +403,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -427,6 +431,7 @@ def forward_backward_pipelining_with_interleaving(
is_first_microbatch_for_model_chunk(virtual_microbatch_id),
),
current_microbatch=microbatch_id,
vp_stage=model_chunk_id,
)
output_tensors[model_chunk_id].append(output_tensor)
......@@ -443,13 +448,8 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor
def backward_step_helper(virtual_microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
nonlocal output_tensor_grads
"""Helper method to run backward step with model split into chunks"""
model_chunk_id = get_model_chunk_id(virtual_microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(
......@@ -459,7 +459,7 @@ def forward_backward_pipelining_with_interleaving(
synchronized_model_chunks.add(model_chunk_id)
# pylint: disable=E0606
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
......@@ -509,7 +509,7 @@ def forward_backward_pipelining_with_interleaving(
if f_virtual_microbatch_id is not None:
model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
f_model_chunk_id = model_chunk_id
f_context = VppContextManager(f_model_chunk_id)
# f_context = VppContextManager(f_model_chunk_id)
with f_context:
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
......@@ -533,7 +533,7 @@ def forward_backward_pipelining_with_interleaving(
)
# forward step
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if parallel_state.is_pipeline_first_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
......@@ -556,7 +556,7 @@ def forward_backward_pipelining_with_interleaving(
if b_virtual_microbatch_id is not None:
model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
b_model_chunk_id = model_chunk_id
b_context = VppContextManager(b_model_chunk_id)
# b_context = VppContextManager(b_model_chunk_id)
with b_context:
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(
......@@ -565,7 +565,7 @@ def forward_backward_pipelining_with_interleaving(
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=model_chunk_id):
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
b_input_tensor = input_tensors[model_chunk_id].pop(0)
......@@ -602,6 +602,7 @@ def forward_backward_pipelining_with_interleaving(
),
),
current_microbatch=f_microbatch_id,
vp_stage=f_model_chunk_id,
)
# forward post process
......@@ -675,8 +676,6 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = None
if f_virtual_microbatch_id is not None:
# forward pass
forward_model_chunk_id = get_model_chunk_id(f_virtual_microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if pre_forward is not None:
pre_forward()
......@@ -689,8 +688,6 @@ def forward_backward_pipelining_with_interleaving(
if b_virtual_microbatch_id is not None:
# Backward pass.
backward_model_chunk_id = get_model_chunk_id(b_virtual_microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if pre_backward is not None:
pre_backward()
input_tensor_grad = backward_step_helper(b_virtual_microbatch_id)
......@@ -698,9 +695,15 @@ def forward_backward_pipelining_with_interleaving(
input_tensor_grad = post_backward(input_tensor_grad)
return output_tensor, input_tensor_grad
is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)
# Run warmup forward passes.
nvtx_range_push(suffix="warmup")
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, config, is_vp_first_stage(vp_stage=0))
)
fwd_wait_handles = None
fwd_wait_recv_handles = None
......@@ -727,10 +730,9 @@ def forward_backward_pipelining_with_interleaving(
for k in range(num_warmup_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if config.overlap_p2p_comm_warmup_flush:
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False) and k != 0:
if not is_vp_first_stage(vp_stage=cur_model_chunk_id) and k != 0:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, iteration {k},'
'should have registered recv handle'
......@@ -777,7 +779,7 @@ def forward_backward_pipelining_with_interleaving(
)
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=cur_model_chunk_id):
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
......@@ -880,8 +882,10 @@ def forward_backward_pipelining_with_interleaving(
if recv_next:
output_tensor_grads[num_model_chunks - 1].append(bwd_recv_buffer[-1])
nvtx_range_pop(suffix="warmup")
# Run 1F1B in steady state.
nvtx_range_push(suffix="steady")
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
......@@ -895,14 +899,15 @@ def forward_backward_pipelining_with_interleaving(
else:
checkpoint_activations_microbatch = None
microbatch_id = get_microbatch_id_in_model_chunk(forward_k, forward=True)
if config.overlap_p2p_comm:
# output send / receive sync
def pp_pre_forward():
nonlocal recv_prev_wait_handles
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if not is_vp_first_stage(vp_stage=cur_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_prev_wait_handles, (
f'pp rank {pipeline_parallel_rank}, fwd iteration {forward_k}, '
......@@ -917,7 +922,6 @@ def forward_backward_pipelining_with_interleaving(
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# output async send / receive
def pp_post_forward(output_tensor):
nonlocal send_next_wait_handle
nonlocal fwd_recv_buffer
......@@ -927,10 +931,9 @@ def forward_backward_pipelining_with_interleaving(
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# Last virtual stage no activation tensor to send.
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -963,8 +966,6 @@ def forward_backward_pipelining_with_interleaving(
recv_prev_wait_handles.append(fwd_wait_handles.pop("recv_prev"))
# assert fwd_wait_handles is not None
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(
fwd_recv_buffer[forward_k % fwd_recv_buffer_size]
......@@ -973,14 +974,13 @@ def forward_backward_pipelining_with_interleaving(
return output_tensor
# Backward pass.
backward_k = k
# grad send receive sync
def pp_pre_backward():
nonlocal recv_next_wait_handles
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if not is_vp_last_stage(vp_stage=backward_model_chunk_id):
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, bwd iteration {backward_k}, '
......@@ -1000,11 +1000,9 @@ def forward_backward_pipelining_with_interleaving(
nonlocal recv_next_wait_handles
nonlocal bwd_recv_buffer
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_next, next_backward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1036,6 +1034,7 @@ def forward_backward_pipelining_with_interleaving(
bwd_recv_buffer[backward_k % bwd_recv_buffer_size]
)
bwd_recv_buffer[(backward_k + 1) % bwd_recv_buffer_size] = None
return input_tensor_grad
output_tensor, input_tensor_grad = forward_backward_helper_wrapper(
......@@ -1061,13 +1060,11 @@ def forward_backward_pipelining_with_interleaving(
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage(ignore_virtual=False):
if is_vp_last_stage(vp_stage=forward_model_chunk_id):
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if is_vp_first_stage(vp_stage=backward_model_chunk_id):
input_tensor_grad = None
recv_prev, next_forward_model_chunk_id = recv_tensor_from_previous_stage(
......@@ -1104,8 +1101,11 @@ def forward_backward_pipelining_with_interleaving(
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
nvtx_range_pop(suffix="steady")
# Run cooldown backward passes (flush out pipeline).
# Run cooldown backward passes (flush out pipeline) for the last model chunk.
nvtx_range_push(suffix="cooldown")
curr_vp_stage = config.virtual_pipeline_model_parallel_size - 1
if not forward_only:
if bwd_wait_handles is not None:
for bwd_wait_handle in bwd_wait_handles.values():
......@@ -1113,12 +1113,15 @@ def forward_backward_pipelining_with_interleaving(
if are_all_microbatches_in_warmup:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
p2p_communication.recv_backward(
tensor_shape,
config=config,
is_last_stage=is_vp_last_stage(vp_stage=curr_vp_stage),
)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
cur_model_chunk_id = get_model_chunk_id(k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(cur_model_chunk_id)
if not parallel_state.is_pipeline_last_stage(ignore_virtual=False) and k != 0:
if not is_vp_last_stage(vp_stage=cur_model_chunk_id) and k != 0:
if config.overlap_p2p_comm_warmup_flush:
assert recv_next_wait_handles, (
f'pp rank {pipeline_parallel_rank}, backward iteration {k}, '
......@@ -1158,7 +1161,7 @@ def forward_backward_pipelining_with_interleaving(
_, input_tensor_grad = forward_backward_helper_wrapper(b_virtual_microbatch_id=k)
# First virtual stage no activation gradient tensor to send.
if parallel_state.is_pipeline_first_stage(ignore_virtual=False):
if is_vp_first_stage(vp_stage=cur_model_chunk_id):
input_tensor_grad = None
if config.overlap_p2p_comm_warmup_flush:
......@@ -1215,7 +1218,9 @@ def forward_backward_pipelining_with_interleaving(
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
nvtx_range_pop(suffix="cooldown")
nvtx_range_push(suffix="misc")
assert (
not recv_prev_wait_handles
), 'recv_prev_wait_handles should be cleared at the end of a step'
......@@ -1245,5 +1250,6 @@ def forward_backward_pipelining_with_interleaving(
if hasattr(config, 'enable_cuda_graph') and config.enable_cuda_graph:
create_cudagraphs()
nvtx_range_pop(suffix="misc")
return forward_data_store
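A recurring change in this schedule is replacing the global set_virtual_pipeline_model_parallel_rank() calls with explicit vp_stage arguments, with ignore_virtual=False pre-bound through functools.partial. Reduced to its essentials (signatures as they appear in the diff; chunk_is_pipeline_edge is an illustrative name):

from functools import partial

from megatron.core import parallel_state

# Bind ignore_virtual=False once; each call site then only supplies the chunk id.
is_vp_first_stage = partial(parallel_state.is_pipeline_first_stage, ignore_virtual=False)
is_vp_last_stage = partial(parallel_state.is_pipeline_last_stage, ignore_virtual=False)


def chunk_is_pipeline_edge(model_chunk_id: int):
    # No global virtual-rank state is mutated; the chunk id is threaded through
    # explicitly, which lets forward and backward helpers for different chunks
    # run back-to-back without re-setting parallel state.
    return (
        is_vp_first_stage(vp_stage=model_chunk_id),
        is_vp_last_stage(vp_stage=model_chunk_id),
    )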
......@@ -87,6 +87,12 @@ def parallel_attention_init_wrapper(fn):
return wrapper
class ParallelAttentionPatch(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_context=None,
rotary_pos_emb=None, *, inference_params=None):
......
......@@ -6,6 +6,15 @@ from functools import wraps
import torch.distributed
import torch
try:
from megatron.post_training.algos.distillation import (
get_tensor_shapes_adjust_fn_for_distillation,
)
has_nvidia_modelopt = True
except ImportError:
has_nvidia_modelopt = False
from megatron.core import mpu
from megatron.core.utils import (
check_param_hashes_across_dp_replicas,
......@@ -92,8 +101,7 @@ def build_train_valid_test_data_iterators_wrapper(build_train_valid_test_data_iterators):
return wrapper
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -118,6 +126,14 @@ def train_step(forward_step_func, data_iterator,
model_chunk.zero_grad_buffer()
optimizer.zero_grad()
if has_nvidia_modelopt:
# [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors
adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(
model, args.seq_length, args.micro_batch_size, args.decoder_seq_length
)
else:
adjust_tensor_shapes_fn = None
# Forward pass.
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
......@@ -128,7 +144,9 @@ def train_step(forward_step_func, data_iterator,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=False)
forward_only=False,
adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
)
should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
if should_exit:
return {}, True, should_checkpoint, should_exit, exit_code, None, None
......@@ -164,9 +182,7 @@ def train_step(forward_step_func, data_iterator,
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
......@@ -189,28 +205,46 @@ def train_step(forward_step_func, data_iterator,
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0].keys():
numerator = 0
denominator = 0
for x in losses_reduced:
val = x[key]
val = [x[key].view(-1) for x in losses_reduced]
if val[0].numel() == 2:
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
if isinstance(val, tuple) or isinstance(val, list):
numerator += val[0]
denominator += val[1]
else:
# legacy behavior. we average over the number of microbatches,
# and so the denominator is 1.
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
val = torch.vstack(val).sum(dim=0)
torch.distributed.all_reduce(
val,
group=mpu.get_data_parallel_group(with_context_parallel=True)
)
loss_reduced[key] = val[0] / val[1]
elif val[0].numel() == 1:
# legacy behavior, we average over the number of microbatches
val = torch.cat(val).mean()
loss_reduced[key] = val
else:
raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")
return (
loss_reduced,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
)
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
def train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
......@@ -220,7 +254,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
try:
from workload_inspector.utils.webserver import run_server
import threading
threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
threading.Thread(
target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")
......@@ -243,11 +280,17 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
one_logger_utils.on_train_start(
iteration=iteration,
consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples,
seq_length=args.seq_length,
train_iters=args.train_iters,
save=args.save,
async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
......@@ -255,9 +298,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
assert config.no_sync_func is None, (
'When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce'
)
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
......@@ -281,8 +325,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
assert (
args.manual_gc_interval >= 0
), 'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
......@@ -292,10 +337,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
stimer.configure(
world,
rank,
mmcnt=mmcnt,
enabled=not args.disable_straggler_on_startup,
port=args.straggler_ctrlr_port,
)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
......@@ -303,10 +351,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
"""Get base metrics values for one-logger to calculate E2E tracking metrics."""
num_floating_point_operations_since_current_train_start = (
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
)
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
......@@ -316,15 +364,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length
'seq_length': args.seq_length,
}
# Cache into one-logger for callback.
if one_logger:
with one_logger.get_context_manager():
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
and args.use_pytorch_profiler
):
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
......@@ -345,14 +398,15 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
wait=max(args.profile_step_start - 1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
active=args.profile_step_end - args.profile_step_start,
repeat=1,
),
on_trace_ready=trace_handler,
record_shapes=True,
with_stack=True,
)
)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes
......@@ -371,8 +425,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
assert check_param_hashes_across_dp_replicas(
model, cross_check=True
), "Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
......@@ -398,14 +453,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
assert get_num_microbatches() > num_microbatches, (
f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}"
)
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
......@@ -414,9 +475,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
......@@ -424,19 +485,28 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
(
loss_dict,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
if should_exit:
break
......@@ -459,12 +529,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
num_skipped_samples_in_batch = (
get_current_global_batch_size() - get_current_running_global_batch_size()
)
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
......@@ -490,16 +561,22 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
report_memory_flag = training_log(
loss_dict,
total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
timers('interval-time').stop()
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
......@@ -509,11 +586,18 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func,
)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -529,13 +613,25 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
iteration,
prof,
num_floating_point_operations_since_last_log_event,
)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
should_exit = checkpoint_and_decide_exit(
model,
optimizer,
opt_param_scheduler,
iteration,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator,
)
if should_exit:
break
......@@ -564,6 +660,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far
......
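The loss-averaging rewrite in train_step switches from per-microbatch averaging to token-weighted averaging whenever a key reports a 2-element [sum_of_losses, num_tokens] tensor. The core of the new path as a standalone sketch (reduce_lm_loss and dp_group are illustrative names; dp_group stands for mpu.get_data_parallel_group(with_context_parallel=True)):

import torch


def reduce_lm_loss(losses_reduced, dp_group):
    # One dict per microbatch; each value is a 2-element tensor
    # [sum_of_losses, num_tokens] produced by the new loss_func.
    vals = [x["lm loss"].view(-1) for x in losses_reduced]
    totals = torch.vstack(vals).sum(dim=0)    # accumulate over microbatches
    # Sum over data-parallel (and context-parallel) ranks before dividing,
    # so the result is the average loss per token across the global batch.
    torch.distributed.all_reduce(totals, group=dp_group)
    return totals[0] / totals[1]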
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import datetime
import os
import torch
from functools import partial
from contextlib import nullcontext
import inspect
from functools import partial
from typing import List, Optional, Tuple, Union
from megatron.core import parallel_state
from megatron.training import get_args
from megatron.training import inprocess_restart
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.training import get_tokenizer
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
from megatron.core.rerun_state_machine import get_rerun_state_machine
import megatron.legacy.model
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.utils import StragglerDetector
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
get_gpt_heterogeneous_layer_spec,
)
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer.spec_utils import import_module
from megatron.core.utils import StragglerDetector
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
get_blend_and_blend_per_split,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from dcu_megatron import megatron_adaptor
import megatron.legacy.model # isort: skip
# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
try:
from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
from megatron.post_training.loss_func import loss_func as loss_func_modelopt
from megatron.post_training.model_provider import model_provider as model_provider_modelopt
has_nvidia_modelopt = True
except ImportError:
has_nvidia_modelopt = False
from dcu_megatron import megatron_adaptor
stimer = StragglerDetector()
def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
def model_provider(
pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
"""Builds the model.
If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
......@@ -55,24 +75,33 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model
"""
args = get_args()
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return model_provider_modelopt(pre_process, post_process)
if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
assert args.transformer_impl == "transformer_engine"
use_te = args.transformer_impl == "transformer_engine"
if args.record_memory_history:
torch.cuda.memory._record_memory_history(True,
torch.cuda.memory._record_memory_history(
True,
# keep 100,000 alloc/free events from before the snapshot
trace_alloc_max_entries=100000,
# record stack information for the trace events
trace_alloc_record_context=True)
trace_alloc_record_context=True,
)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
dump(
snapshot,
open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
)
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
......@@ -91,27 +120,42 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
pre_process=pre_process,
post_process=post_process,
)
else: # using core models
else: # using core models
if args.spec is not None:
transformer_layer_spec = import_module(args.spec)
else:
if args.num_experts:
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization)
transformer_layer_spec = get_gpt_decoder_block_spec(
config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
)
elif args.heterogeneous_layers_config_path is not None:
transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
else:
# Define the decoder layer spec
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
qk_l2_norm=args.qk_l2_norm
)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm,
normalization=args.normalization)
args.num_experts,
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.moe_use_legacy_grouped_gemm,
normalization=args.normalization,
)
mtp_block_spec = None
if args.mtp_num_layers is not None:
mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te)
mtp_block_spec = get_gpt_mtp_block_spec(
config, transformer_layer_spec, use_transformer_engine=use_te, vp_stage=vp_stage
)
model = GPTModel(
config=config,
......@@ -128,16 +172,19 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling,
mtp_block_spec=mtp_block_spec,
vp_stage=vp_stage,
)
print_rank_0(model)
return model
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
return None, None, None, None, None
# get batches based on the TP rank you are on
......@@ -153,12 +200,15 @@ def get_batch(data_iterator):
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
def loss_func(
loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
"""Loss function.
Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses
model (GPTModel, optional): The model (can be wrapped)
Returns:
the loss scalar for this micro-batch
......@@ -168,57 +218,48 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
if has_nvidia_modelopt and modelopt_args_enabled(args): # [ModelOpt]
return loss_func_modelopt(loss_mask, output_tensor, model=model)
if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
losses = output_tensor.view(-1).float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses * loss_mask)
# Check individual rank losses are not NaN prior to DP all-reduce.
rerun_state_machine = get_rerun_state_machine()
if args.check_for_nan_in_loss_and_grad:
rerun_state_machine.validate_result(
result=loss[0],
result=loss,
rejection_func=torch.isnan,
message="found NaN in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
tolerance=0.0,  # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
result=loss,
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
tolerance=0.0,  # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
result=loss,
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
tolerance=0.0, # forward pass calculations are deterministic
tolerance=0.0,  # forward pass calculations are deterministic
fatal=False,
)
# Reduce loss for logging.
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
# loss[0] is a view of loss, so it has ._base not None, which triggers assert error
# in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
# on loss[0] fixes this
local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
loss[0].clone(),
local_num_tokens,
{'lm loss': (reporting_loss[0], reporting_loss[1])},
)
num_tokens = loss_mask.sum().clone().detach().to(torch.int)
reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
return (loss, num_tokens, {'lm loss': reporting_loss})
def forward_step(data_iterator, model: GPTModel):
......@@ -235,25 +276,26 @@ def forward_step(data_iterator, model: GPTModel):
timers('batch-generator', log_level=2).start()
global stimer
with stimer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch-generator').stop()
with stimer:
if args.use_legacy_models:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels, loss_mask=loss_mask)
output_tensor = model(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
return output_tensor, partial(loss_func, loss_mask)
# [ModelOpt]: model is needed to access ModelOpt distillation losses
return output_tensor, partial(loss_func, loss_mask, model=model)
def is_dataset_built_on_rank():
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0
parallel_state.is_pipeline_first_stage(ignore_virtual=True)
or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
) and parallel_state.get_tensor_model_parallel_rank() == 0
def core_gpt_dataset_config_from_args(args):
......@@ -278,7 +320,8 @@ def core_gpt_dataset_config_from_args(args):
reset_attention_mask=args.reset_attention_mask,
eod_mask_loss=args.eod_mask_loss,
create_attention_mask=args.create_attention_mask_in_dataloader,
s3_cache_path=args.s3_cache_path,
object_storage_cache_path=args.object_storage_cache_path,
mid_level_dataset_surplus=args.mid_level_dataset_surplus,
)
......@@ -300,10 +343,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0("> building train, validation, and test datasets for GPT ...")
train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
dataset_type,
train_val_test_num_samples,
is_dataset_built_on_rank,
config
dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
).build()
print_rank_0("> finished creating GPT datasets ...")
......@@ -316,10 +356,15 @@ if __name__ == "__main__":
# Temporary for transition to core datasets
train_valid_test_datasets_provider.is_distributed = True
# Optionally enable inprocess restart on pretrain
pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
store=store,
)
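For reference, the new loss_func contract condensed from the hunk above: it returns the unscaled summed loss, the local token count, and a [loss, tokens] pair for logging; it no longer all-reduces inside loss_func, since that reduction now happens in train_step:

import torch


def loss_func_sketch(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    # Summed (not averaged) loss over unmasked tokens.
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses * loss_mask)
    # Local token count; the DP/CP all-reduce is deferred to train_step.
    num_tokens = loss_mask.sum().clone().detach().to(torch.int)
    reporting_loss = torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])
    return loss, num_tokens, {"lm loss": reporting_loss}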