Commit 688448db authored by silencealiang's avatar silencealiang

Update code

parent a02a5490
Pipeline #2503 passed with stage
......@@ -10,6 +10,9 @@ import logging
import math
import os
import sys
from typing import List
import torch.distributed
from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
......@@ -32,6 +35,7 @@ from megatron.training.checkpointing import checkpoint_exists
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
......@@ -51,6 +55,10 @@ from megatron.core.rerun_state_machine import (
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
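# get_batch_on_this_tp_rank / get_batch_on_this_cp_rank are reused by dummy_train_step()
# below to drain microbatches from the data iterator when an iteration is skipped.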
from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils
......@@ -127,6 +135,10 @@ def num_floating_point_operations(args, batch_size):
if args.moe_shared_expert_intermediate_size is None
else args.moe_shared_expert_intermediate_size
)
if args.num_experts is None:
ffn_hidden_size = args.ffn_hidden_size
else:
ffn_hidden_size = args.moe_ffn_hidden_size
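    # For MoE models the per-expert MLP width is moe_ffn_hidden_size; the FLOPs estimate
    # below scales this width by the number of experts each token is routed to.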
# The 12x term below comes from the following factors; for more details, see
# "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
......@@ -156,7 +168,7 @@ def num_floating_point_operations(args, batch_size):
)
# MLP.
+ (
(args.ffn_hidden_size / args.hidden_size)
(ffn_hidden_size / args.hidden_size)
* num_experts_routed_to
* gated_linear_multiplier
)
......@@ -283,6 +295,12 @@ def pretrain(
if args.log_progress:
append_to_progress_log("Starting job")
# Initialize fault tolerance
    # NOTE: ft_integration functions other than `setup` are no-ops if FT is not initialized
if args.enable_ft_package:
ft_integration.setup(args)
ft_integration.maybe_setup_simulated_fault()
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
......@@ -311,10 +329,28 @@ def pretrain(
# Context used for persisting some state between checkpoint saves.
if args.non_persistent_ckpt_type == 'local':
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
try:
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
LocalCheckpointManager
from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import \
parse_group_sequence, GroupWrapper
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import \
CliqueReplicationStrategy
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
if args.replication:
repl_strategy = CliqueReplicationStrategy.from_replication_params(
args.replication_jump,
args.replication_factor
)
else:
repl_strategy = None
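        # Presumably, CliqueReplicationStrategy replicates each rank's local checkpoint across
        # a group of ranks derived from --replication-jump and --replication-factor, so a lost
        # node can recover its shard from peers; None disables replication.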
checkpointing_context = {
'local_checkpoint_manager': BasicLocalCheckpointManager(
args.non_persistent_local_ckpt_dir
'local_checkpoint_manager': LocalCheckpointManager(args.non_persistent_local_ckpt_dir,
repl_strategy=repl_strategy
)
}
else:
......@@ -360,11 +396,6 @@ def pretrain(
args.do_valid, args.do_test, args.dataloader_type,
args.retro_project_dir, args.retro_cyclic_train_iters)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().init_workload_monitoring()
ft_timeouts = ft_integration.get_rank_monitor_client().timeouts
print_rank_0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup',
......@@ -396,8 +427,7 @@ def pretrain(
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context,
train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
preprocess_common_state_dict_fn=preprocess_common_state_dict)
one_logger and one_logger.log_metrics({
'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
......@@ -427,11 +457,16 @@ def pretrain(
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
maybe_finalize_async_save(blocking=True)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
one_logger and one_logger.log_metrics({
'app_finish_time': one_logger_utils.get_timestamp_in_ms()
})
ft_integration.shutdown()
one_logger_utils.finish()
......@@ -472,6 +507,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args.model_type = model_type
# Build model.
def build_model():
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
......@@ -513,6 +549,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
post_process=post_process
)
model.model_type = model_type
return model
if args.init_model_with_meta_device:
with torch.device('meta'):
model = build_model()
else:
model = build_model()
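    # With a meta-device build, parameters carry no storage here; they are expected to be
    # materialized later (e.g. when FSDP shards them or a checkpoint is loaded).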
if not isinstance(model, list):
model = [model]
......@@ -526,15 +568,21 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
num_parameters = sum(
[sum([p.nelement() for p in model_module.parameters()])
for model_module in model]
)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
num_parameters), flush=True)
# GPU allocation.
# For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
# in the fully_shard function of FSDP2 instead.
if not (args.use_torch_fsdp2 and args.use_cpu_initialization) and not args.init_model_with_meta_device:
for model_module in model:
model_module.cuda(torch.cuda.current_device())
......@@ -558,9 +606,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
fp8_meta.amax_history[0][fp8_meta_index] = 0
if wrap_with_ddp:
if getattr(args, "use_torch_fsdp2", False):
if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP
elif args.use_custom_fsdp:
DP = custom_FSDP
else:
DP = DDP
......@@ -572,17 +622,42 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
kwargs['check_for_large_grads'] = args.check_for_large_grads
if args.ddp_num_buckets is not None:
assert args.ddp_bucket_size is None, \
"Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
assert args.ddp_num_buckets > 0, \
"--ddp-num-buckets must be greater than 0"
kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
else:
kwargs['bucket_size'] = args.ddp_bucket_size
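        # Illustration: with --ddp-num-buckets 4 and ~1.2e9 local parameters, each bucket holds
        # num_parameters // 4 = 3e8 elements; otherwise --ddp-bucket-size is used as given.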
kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
kwargs['average_in_collective'] = args.ddp_average_in_collective
if args.use_custom_fsdp and args.use_precision_aware_optimizer:
kwargs["preserve_fp32_weights"] = False
ddp_config = DistributedDataParallelConfig(**kwargs)
overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False)
if not getattr(args, "use_torch_fsdp2", False):
            # In the custom FSDP and DDP code paths, we need to initialize the bucket size.
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
)
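            # e.g. with a data-parallel size of 512 this default is max(40e6, 1e6 * 512) = 512M elements.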
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
model = [DP(config=config,
ddp_config=ddp_config,
module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step)
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step)
for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks.
......@@ -670,7 +745,8 @@ def setup_model_and_optimizer(model_provider_func,
config = OptimizerConfig(**kwargs)
config.timers = timers
optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
scale_lr_cond, lr_mult,
use_gloo_process_groups=args.enable_gloo_process_groups)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.moe_use_upcycling:
......@@ -709,9 +785,8 @@ def setup_model_and_optimizer(model_provider_func,
timers('load-checkpoint', log_level=0).start(barrier=True)
args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
model, optimizer, opt_param_scheduler,
ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and getattr(args, "use_torch_fsdp2", False))
model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and args.use_torch_fsdp2)
timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint'])
one_logger and one_logger.log_metrics({
......@@ -748,6 +823,15 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler
def dummy_train_step(data_iterator):
"""Single dummy training step."""
num_microbatches = get_num_microbatches()
for _ in range(num_microbatches):
# Re-use methods used in get_batch() from pretrain_{gpt, mamba}.py.
batch = get_batch_on_this_tp_rank(data_iterator)
batch = get_batch_on_this_cp_rank(batch)
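        # Consuming the same number of microbatches keeps the data iterator aligned with a
        # real step, so subsequent (non-skipped) iterations see the expected samples.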
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
......@@ -781,7 +865,7 @@ def train_step(forward_step_func, data_iterator,
torch.cuda.empty_cache()
# Vision gradients.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
......@@ -801,7 +885,7 @@ def train_step(forward_step_func, data_iterator,
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)
......@@ -927,12 +1011,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path , 'wb') as f:
dump(snapshot, f)
if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
......@@ -1000,6 +1078,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-max-allocated-bytes",
mem_stats["allocated_bytes.all.peak"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
......@@ -1010,6 +1093,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
if iteration % args.log_interval == 0:
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path, 'wb') as f:
dump(snapshot, f)
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
......@@ -1143,26 +1232,23 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
# Extra barrier is added to make sure all ranks report the max time.
timer_key = 'save-checkpoint-non-persistent' if non_persistent_ckpt else 'save-checkpoint'
timers(timer_key, log_level=0).start(barrier=True)
save_checkpoint_start_time = timers('save-checkpoint').active_time()
# Log E2E metrics before save-checkpoint
one_logger_utils.track_e2e_metrics()
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context,
non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
if args.use_distributed_optimizer and args.overlap_param_gather:
preprocess_common_state_dict_fn=preprocess_common_state_dict)
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
timers(timer_key).stop(barrier=True)
timers.log([timer_key])
save_checkpoint_finish_time = timers('save-checkpoint').active_time()
# Log E2E metrics after save-checkpoint
one_logger_utils.track_e2e_metrics()
save_checkpoint_duration = save_checkpoint_finish_time - save_checkpoint_start_time
save_checkpoint_duration = timers(timer_key).elapsed()
one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save)
if args.log_progress and not non_persistent_ckpt:
......@@ -1178,21 +1264,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
"""Run all post-training-step functions (e.g., FT heartbeats, GC)."""
args = get_args()
# Send heartbeat to FT package and update timeouts.
if args.enable_ft_package:
ft_client = ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.TRAIN_HEARTBEAT)
if ft_client is not None:
ft_client.send_heartbeat()
# TODO: We are always calculating timeouts in the current implementation.
# If we want to rely on manually setting these, then we need to add additional
# arguments to training and pass it here.
if ft_integration.can_update_timeouts():
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.UPDATE_TIMEOUT).calculate_and_set_timeouts()
print_rank_0(f'Updated FT timeouts. New values: \
{ft_integration.get_rank_monitor_client().timeouts}')
# Bring CPU and GPU back in sync if on right iteration.
if args.train_sync_interval and iteration % args.train_sync_interval == 0:
torch.cuda.synchronize()
......@@ -1205,13 +1276,13 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
# Check weight hash across DP replicas.
if args.check_weight_hash_across_dp_replicas_interval is not None and \
iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
# Autoresume.
......@@ -1270,14 +1341,12 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
elif args.save and args.non_persistent_save_interval and \
iteration % args.non_persistent_save_interval == 0:
timers('interval-time').stop()
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
non_persistent_ckpt=True, train_data_iterator=train_data_iterator)
saved_checkpoint = True
timers('interval-time', log_level=0).start(barrier=True)
# Exit based on duration.
if args.exit_duration_in_mins:
......@@ -1333,6 +1402,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations.
iteration = args.iteration
# Make sure rerun_state_machine has the right iteration loaded from checkpoint.
rerun_state_machine = get_rerun_state_machine()
if rerun_state_machine.current_iteration != iteration:
print_rank_0(f"Setting rerun_state_machine.current_iteration to {iteration}...")
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
......@@ -1346,7 +1420,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Setup some training config params.
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce:
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
......@@ -1397,12 +1471,14 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration,
'eval_iterations': eval_iterations,
'total_flops': num_floating_point_operations_since_last_log_event,
'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
......@@ -1415,7 +1491,6 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
......@@ -1444,15 +1519,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
#on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
on_trace_ready=trace_handler)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")
start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
......@@ -1471,14 +1543,13 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.step()
elif args.use_hip_profiler:
if iteration == args.profile_step_start: roctracer.roctracer_start()
if iteration == args.profile_step_end: roctracer.roctracer_stop()
elif iteration == args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=False)
ft_integration.on_checkpointing_end(is_async_finalization=True)
# Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different
......@@ -1497,8 +1568,21 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
# Completely skip iteration if needed.
if iteration in args.iterations_to_skip:
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
......@@ -1506,6 +1590,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
......@@ -1527,7 +1612,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
......@@ -1575,7 +1660,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
......@@ -1596,15 +1681,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
......@@ -1628,16 +1709,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if pre_hook_enabled:
disable_forward_pre_hook(model)
ft_integration.on_checkpointing_start()
    # This will finalize all unfinalized async requests and terminate
    # the persistent async worker if the persistent ckpt worker is enabled
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()
maybe_finalize_async_save(blocking=True)
# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit:
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far
......@@ -1688,6 +1773,7 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func()
# Don't care about timing during evaluation
config.timers = None
ft_integration.on_eval_step_start()
loss_dicts = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
......@@ -1697,6 +1783,7 @@ def evaluate(forward_step_func,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=True)
ft_integration.on_eval_step_end()
config.timers = get_timers()
# Empty unused memory
......@@ -1955,3 +2042,8 @@ def build_train_valid_test_data_iterators(
test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator
def should_disable_forward_pre_hook(args):
"""Block forward pre-hook for certain configurations."""
return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather
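# The forward pre-hook here is the distributed optimizer's parameter all-gather hook used with
# --overlap-param-gather; it is left untouched under custom FSDP, presumably because custom FSDP
# manages its own parameter gathering.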
......@@ -33,6 +33,7 @@ from megatron.training import (
get_adlr_autoresume,
)
from megatron.core import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
......@@ -46,9 +47,9 @@ from megatron.legacy.model.module import param_is_not_shared
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module)
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module)
except ImportError:
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module)
def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
......@@ -66,7 +67,7 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return unwrapped_model
def calc_params_l2_norm(model):
def calc_params_l2_norm(model, force_create_fp32_copy=False):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
......@@ -74,57 +75,110 @@ def calc_params_l2_norm(model):
    # Separate MoE and dense params
params_data = []
moe_params_data = []
sharded_params_data = []
data_parallel_group = None
custom_fsdp_all_param_is_shared = False
for model_chunk in model:
for i, param in enumerate(model_chunk.parameters()):
for param in model_chunk.parameters():
data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if not (param.requires_grad and is_not_tp_duplicate):
if not is_not_tp_duplicate:
continue
assert is_not_tp_duplicate
if hasattr(param, "fully_shard_param_local_shard"):
param = param.fully_shard_param_local_shard
                assert all(getattr(p, "fully_shard_param_local_shard", None) is not None for p in model_chunk.parameters())
custom_fsdp_all_param_is_shared = True
if param.numel() == 0:
continue
if not getattr(param, 'allreduce', True):
# TODO: Implement memory optimization for MoE parameters.
assert param_is_not_shared(param)
param = to_local_if_dtensor(param)
moe_params_data.append(param.data.float() if args.bf16 else param.data)
else:
if param_is_not_shared(param):
param = to_local_if_dtensor(param)
params_data.append(param.data.float() if args.bf16 else param.data)
if args.bf16:
if not force_create_fp32_copy and hasattr(param, 'main_param'):
if getattr(param, 'main_param_sharded', False):
if param.main_param is not None:
sharded_params_data.append(param.main_param)
else:
params_data.append(param.main_param)
else:
# Fallback to original logic of making a fp32 copy of the
# parameter if `.main_param` attribute is not available.
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate dense param norm
# Calculate norm.
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
if len(params_data) > 0:
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
False # no per-parameter norm.
)
norm_2 = norm * norm
else:
norm_2 = torch.tensor([0.0], dtype=torch.float32, device='cuda')
norm_2 = torch.zeros((1,), dtype=torch.float32, device='cuda')
if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=data_parallel_group)
# Sum across all model-parallel GPUs(tensor + pipeline).
    # Add norm contribution from params with sharded main_params. These norms need to be
    # accumulated across the DP group since the main parameters are sharded by the
    # distributed optimizer.
if len(sharded_params_data) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
sharded_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[sharded_params_data],
False # no per-parameter norm.
)
sharded_norm_2 = sharded_norm * sharded_norm
# Sum over all DP groups.
torch.distributed.all_reduce(
sharded_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group()
)
norm_2 += sharded_norm_2
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group())
# Sum across all model-parallel GPUs (tensor + pipeline).
torch.distributed.all_reduce(
norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group()
)
# Calculate moe norm
# Add norm contribution from expert layers in MoEs.
if len(moe_params_data) > 0:
moe_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[moe_params_data],
False # no per-parameter norm
False # no per-parameter norm.
)
moe_norm_2 = moe_norm * moe_norm
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(moe_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_expert_data_parallel_group())
# Sum across expert tensor, model and pipeline parallel GPUs.
torch.distributed.all_reduce(
moe_norm_2,
......@@ -132,6 +186,7 @@ def calc_params_l2_norm(model):
group=mpu.get_expert_tensor_model_pipeline_parallel_group()
)
norm_2 += moe_norm_2
return norm_2.item() ** 0.5
......@@ -304,6 +359,10 @@ def print_rank_0(message):
else:
print(message, flush=True)
def is_rank0():
"""Returns true if called in the rank0, false otherwise"""
return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
......@@ -316,6 +375,9 @@ def print_rank_last(message):
else:
print(message, flush=True)
def get_device_arch_version():
"""Returns GPU arch version (8: Ampere, 9: Hopper, 10: Blackwell, ...)"""
return torch.cuda.get_device_properties(torch.device("cuda:0")).major
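# Hypothetical usage: gate an architecture-specific code path with
#   if get_device_arch_version() >= 9:  # Hopper or newer
#       ...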
def append_to_progress_log(string, barrier=True):
"""Append given string to progress log."""
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from pathlib import Path
from typing import Tuple
from megatron.training.global_vars import get_wandb_writer
from megatron.training.utils import print_rank_last
def _get_wandb_artifact_tracker_filename(save_dir: str) -> Path:
"""Wandb artifact tracker file rescords the latest artifact wandb entity and project"""
return Path(save_dir) / "latest_wandb_artifact_path.txt"
def _get_artifact_name_and_version(save_dir: Path, checkpoint_path: Path) -> Tuple[str, str]:
return save_dir.stem, checkpoint_path.stem
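# Example with hypothetical paths: save_dir '.../runs/llama-7b' and checkpoint_path
# '.../runs/llama-7b/iter_0005000' yield ('llama-7b', 'iter_0005000').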
def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None:
"""Function to be called after checkpointing succeeds and checkpoint is persisted for logging it as an artifact in W&B
Args:
checkpoint_path (str): path of the saved checkpoint
tracker_filename (str): path of the tracker filename for the checkpoint iteration
save_dir (str): path of the root save folder for all checkpoints
iteration (int): iteration of the checkpoint
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
metadata = {"iteration": iteration}
artifact_name, artifact_version = _get_artifact_name_and_version(Path(save_dir), Path(checkpoint_path))
artifact = wandb_writer.Artifact(artifact_name, type="model", metadata=metadata)
artifact.add_reference(f"file://{checkpoint_path}", checksum=False)
artifact.add_file(tracker_filename)
wandb_writer.run.log_artifact(artifact, aliases=[artifact_version])
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(save_dir)
wandb_tracker_filename.write_text(f"{wandb_writer.run.entity}/{wandb_writer.run.project}")
def on_load_checkpoint_success(checkpoint_path: str, load_dir: str) -> None:
"""Function to be called after succesful loading of a checkpoint, for aggregation and logging it to W&B
Args:
checkpoint_path (str): path of the loaded checkpoint
        load_dir (str): path of the root folder from which checkpoints are loaded
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
try:
artifact_name, artifact_version = _get_artifact_name_and_version(Path(load_dir), Path(checkpoint_path))
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(load_dir)
artifact_path = ""
if wandb_tracker_filename.is_file():
artifact_path = wandb_tracker_filename.read_text().strip()
artifact_path = f"{artifact_path}/"
wandb_writer.run.use_artifact(f"{artifact_path}{artifact_name}:{artifact_version}")
except Exception:
print_rank_last(f" failed to find checkpoint {checkpoint_path} in wandb")
\ No newline at end of file
......@@ -59,7 +59,7 @@ def validate_yaml(args, defaults={}):
(args.world_size // args.model_parallel.tensor_model_parallel_size))
args.model_parallel.transformer_pipeline_model_parallel_size = (
args.model_parallel.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
if args.account_for_embedding_in_pipeline_split else
args.model_parallel.pipeline_model_parallel_size
)
# Checks.
......
node021 slots=8
node022 slots=8
\ No newline at end of file
......@@ -154,7 +154,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path)
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
path_to_cache=args.data_cache_path,
tokenizer=tokenizer,
......
......@@ -35,8 +35,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
import torch._dynamo
torch._dynamo.config.suppress_errors = True
stimer = StragglerDetector()
......@@ -64,6 +63,15 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
# record stack information for the trace events
trace_alloc_record_context=True)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
if args.yaml_cfg is not None:
......@@ -128,8 +136,6 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling
)
model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model)
return model
......@@ -150,8 +156,8 @@ def get_batch(data_iterator):
return batch.values()
# define spiky loss as a variation of 20% or more
SPIKY_LOSS_PERC = 0.2
# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
......@@ -187,11 +193,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
        tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
        tolerance=0.0, # forward pass calculations are deterministic
fatal=False,
......@@ -252,7 +269,6 @@ def core_gpt_dataset_config_from_args(args):
sequence_length=args.seq_length,
blend=blend,
blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
......
......@@ -104,8 +104,8 @@ def get_batch(data_iterator):
return batch.values()
# define spiky loss as a variation of 20% or more
SPIKY_LOSS_PERC = 0.2
# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
......@@ -141,11 +141,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
        tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
        tolerance=0.0, # forward pass calculations are deterministic
fatal=False,
......@@ -207,7 +218,6 @@ def core_gpt_dataset_config_from_args(args):
sequence_length=args.seq_length,
blend=blend,
blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
......
......@@ -189,7 +189,6 @@ def train_valid_test_datasets_provider(train_valid_test_num_samples):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path)
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
split_preprocessing=retro_config.retro_split_preprocessing,
path_to_cache=args.data_cache_path,
......
......@@ -141,6 +141,8 @@ def model_provider(
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
relative_attention_num_buckets=args.relative_attention_num_buckets,
relative_attention_max_distance=args.relative_attention_max_distance,
add_encoder=add_encoder,
add_decoder=add_decoder,
)
......@@ -226,7 +228,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples: int):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path),
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
path_to_cache=args.data_cache_path,
tokenizer=tokenizer,
......
......@@ -22,42 +22,13 @@ from megatron.core.models.vision.vit_layer_specs import (
get_vit_layer_with_local_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import get_batch_on_this_cp_rank
from megatron.core import mpu
from megatron.core.models.multimodal import context_parallel
from pretrain_gpt import loss_func
def calculate_model_parallel_padding(decoder_seq_len, text_only=False):
args = get_args()
cp_size = args.context_parallel_size
tp_size = args.tensor_model_parallel_size
mp_padding_needed = 0
# TP Comm overlap is performed with combined text+image embeddings.
# text_only flag skips using the full sequence length to calculate padding and uses
# the provided decoder_seq_len
if args.sequence_parallel and args.decoder_tp_comm_overlap and not text_only:
# If TP Comm Overlap is enabled for combined text+image embedding in LM backbone,
# user needs to provide decoder_seq_length with any potential padding needed for SP+CP
assert args.decoder_seq_length is not None, \
"Please provide --decoder-seq-length when using TP Comm overlap for LM backbone"
mp_padding_needed = args.decoder_seq_length - decoder_seq_len
elif args.sequence_parallel or cp_size > 1:
if args.sequence_parallel and cp_size > 1:
# Padding to multiple of tp_size * cp_size*2 when using sequence parallel and context parallel
padding_factor = tp_size * cp_size * 2
elif cp_size > 1:
padding_factor = cp_size * 2
elif args.sequence_parallel:
padding_factor = tp_size
mp_padding_needed = int((decoder_seq_len + padding_factor - 1) // (padding_factor) * (padding_factor)) - decoder_seq_len
args.decoder_seq_length = decoder_seq_len + mp_padding_needed
else:
args.decoder_seq_length = decoder_seq_len
return mp_padding_needed
def model_provider(
pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
......@@ -82,6 +53,8 @@ def model_provider(
vision_model_type = "clip"
assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently."
assert not (args.context_parallel_size > 1 and args.pipeline_model_parallel_size > 1), "PP+CP is not yet supported by this script. \
Current mock dataset does not support natively packed sequence dataset required for correct PP comm shapes."
num_image_embeddings = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token,
......@@ -102,7 +75,15 @@ def model_provider(
warnings.warn(
f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
)
mp_padding_needed = calculate_model_parallel_padding(decoder_seq_len)
mp_padding_needed = context_parallel.get_padding(
decoder_seq_len,
args.context_parallel_size,
args.tensor_model_parallel_size,
args.sequence_parallel,
args.decoder_tp_comm_overlap,
args.decoder_seq_length
)
args.decoder_seq_length = decoder_seq_len + mp_padding_needed
args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length)
......@@ -183,11 +164,16 @@ def model_provider(
if args.virtual_pipeline_model_parallel_size:
raise NotImplementedError("virtual pipeline model parallelism is not supported yet.")
language_max_sequence_length = args.decoder_seq_length
if args.context_parallel_size > 1:
if args.use_packed_sequence or mp_padding_needed > 0:
# Use THD data format
language_max_sequence_length = args.decoder_seq_length * args.micro_batch_size
model = LLaVAModel(
language_transformer_config=language_transformer_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.decoder_seq_length,
language_max_sequence_length=language_max_sequence_length,
vision_transformer_config=vision_transformer_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
......@@ -289,6 +275,7 @@ def _preprocess_data_for_llava(data):
return data
def get_batch(data_iterator):
"""Generate a batch.
......@@ -298,33 +285,6 @@ def get_batch(data_iterator):
Returns:
sample: A data sample with images, tokens, etc.
"""
def _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed):
batch_size = tokens.shape[0]
# Calculate the valid token seq len that LM backbone should compute on
combined_valid_seqlen = tokens.shape[1] + img_seq_len - mp_padding_needed
cu_seqlens = torch.arange(
0, (batch_size + 1) * (combined_valid_seqlen), step=(combined_valid_seqlen), dtype=torch.int32, device=tokens.device)
# Calculate the total padded token seq len
combined_padded_seqlen = tokens.shape[1] + img_seq_len
cu_seqlens_padded = None
qkv_format = 'sbhd'
if cp_size > 1:
# Provide cu_seqlens_<q/kv>_padded for CP support
cu_seqlens_padded = torch.arange(
0, (batch_size + 1) * (combined_padded_seqlen), step=(combined_padded_seqlen), dtype=torch.int32, device=tokens.device)
# CP with padding mask type requires THD format
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=combined_padded_seqlen,
max_seqlen_kv=combined_padded_seqlen,
qkv_format=qkv_format,
)
return packed_seq_params
args = get_args()
cp_size = args.context_parallel_size
# Broadcast data.
......@@ -353,20 +313,41 @@ def get_batch(data_iterator):
args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, 1
)
# Pad to make sure the text sequence can be sharded equally by CP chunks.
mp_padding_needed_for_text = calculate_model_parallel_padding(tokens.shape[1], text_only=True)
if mp_padding_needed_for_text > 0:
tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
# Image token mask must be supplied before distributed sequence to CP ranks.
image_token_mask = tokens == DEFAULT_IMAGE_TOKEN_INDEX
num_images_per_sample = torch.sum(image_token_mask, dim=-1)
img_seq_len = (num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample).max()
packed_seq_params = _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text)
mp_padding_needed_for_text = context_parallel.get_padding(
tokens.shape[1] + img_seq_len,
args.context_parallel_size,
args.tensor_model_parallel_size,
args.sequence_parallel,
args.decoder_tp_comm_overlap,
args.decoder_seq_length
)
if mp_padding_needed_for_text > 0:
tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
packed_seq_params = context_parallel.get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text, cp_size, args.use_packed_sequence)
if packed_seq_params.qkv_format == 'thd':
        # Reshape from [B, S] to [1, T] so sequences are packed along a single token dimension
tokens = (
tokens.contiguous()
.view(tokens.shape[0] * tokens.shape[1])
.unsqueeze(0)
)
position_ids = (
position_ids.contiguous()
.view(position_ids.shape[0] * position_ids.shape[1])
.unsqueeze(0)
)
labels = labels.view(labels.shape[0] * labels.shape[1]).unsqueeze(0)
loss_mask = loss_mask.view(
loss_mask.shape[0] * loss_mask.shape[1]
).unsqueeze(0)
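        # After these reshapes all four tensors are [1, B*S]; the cu_seqlens in packed_seq_params
        # mark the per-sample boundaries inside this single packed sequence.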
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank({"tokens": tokens, "position_ids": position_ids})
    attention_mask = None # Use the attention mask type defined in layer spec. Typically no mask for the vision model and causal mask for the language model.
return batch["tokens"], batch["position_ids"], labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params
return tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params
def forward_step(data_iterator, model: LLaVAModel):
......@@ -384,11 +365,11 @@ def forward_step(data_iterator, model: LLaVAModel):
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, position_ids, labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params = get_batch(data_iterator)
tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params = get_batch(data_iterator)
timers('batch-generator').stop()
output_tensor, loss_mask = model(
images, tokens, position_ids, attention_mask, labels, loss_mask, image_token_mask=image_token_mask, packed_seq_params=packed_seq_params
images, tokens, position_ids, attention_mask, labels, loss_mask, packed_seq_params=packed_seq_params
)
return output_tensor, partial(loss_func, loss_mask)
......@@ -413,6 +394,12 @@ def add_vlm_extra_args(parser):
group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False, help="Enables the overlap of "
"Tensor parallel communication and GEMM kernels in Decoder only. "
"Please provide decoder-seq-length when using this feature.")
group.add_argument(
"--use-packed-sequence",
action="store_true",
default=False,
help="Use packed sequence",
)
return parser
......
......@@ -2,3 +2,5 @@
[pytest]
markers =
internal: mark a test as a test to private/internal functions.
flaky: mark flaky tests for LTS environment
flaky_in_dev: mark flaky tests for DEV environment
......@@ -2,6 +2,7 @@ einops
flask-restful
nltk
pytest
pytest_asyncio
pytest-cov
pytest_mock
pytest-random-order
......@@ -11,5 +12,5 @@ wrapt
zarr
wandb
triton==2.1.0
tensorstore==0.1.45
tensorstore!=0.1.46,!=0.1.72
nvidia-modelopt[torch]>=0.19.0; sys_platform != "darwin"
......@@ -2,6 +2,7 @@ einops
flask-restful
nltk
pytest
pytest_asyncio
pytest-cov
pytest_mock
pytest-random-order
......@@ -10,5 +11,6 @@ tiktoken
wrapt
zarr
wandb
tensorstore==0.1.45
tensorstore!=0.1.46
nvidia-modelopt[torch]>=0.19.0; sys_platform != "darwin"
nvidia-resiliency-ext
einops
flask-restful
nltk
pytest
pytest_asyncio
pytest-cov
pytest_mock
pytest-random-order
sentencepiece
tiktoken
wrapt
zarr
wandb
tensorstore!=0.1.46,!=0.1.72
torch
nvidia-modelopt[torch]>=0.19.0; sys_platform != "darwin"
nvidia-resiliency-ext; sys_platform != "darwin"
export TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
export TORCHINDUCTOR_BENCHMARK_FUSION=1
export TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1
# export TORCHINDUCTOR_BENCHMARK_KERNEL=1
export TORCHINDUCTOR_MAX_AUTOTUNE=1
#export FLASH_ATTENTION_PRINT_PARAM=1
export TORCHINDUCTOR_CACHE_DIR=./cache
# export USE_AOTRITON_FA=1
# export USE_BSHD=1 # use fa bsdh layout
# for unique kernel names
#export TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1
mpirun --allow-run-as-root -np 8 ./Llama_pretraining.sh localhost
......@@ -29,10 +29,15 @@ long_description_content_type = "text/markdown"
def req_file(filename, folder="requirements"):
environment = os.getenv("PY_ENV", "pytorch:24.07")
environment = os.getenv("PY_ENV", "pytorch_24.10")
content = []
with open(os.path.join(folder, environment, filename), encoding='utf-8') as f:
content = f.readlines()
content += f.readlines()
with open(os.path.join("megatron", "core", "requirements.txt"), encoding='utf-8') as f:
content += f.readlines()
# you may also want to remove whitespace characters
# Example: `\n` at the end of each line
return [x.strip() for x in content]
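# e.g. PY_ENV=pytorch_24.10 reads requirements/pytorch_24.10/<filename>, then appends
# megatron/core/requirements.txt before stripping whitespace from each entry.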
......
......@@ -8,6 +8,7 @@ from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_wiki_datas
from megatron.legacy.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.legacy.model.biencoder_model import get_model_provider
from megatron.training import get_model
from megatron.core.parallel_state import create_group
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
......@@ -116,7 +117,7 @@ class ORQAEvaluator(object):
start_rank = node * device_count
end_rank = (node + 1) * device_count
ranks_list = list(range(start_rank, end_rank))
node_group = torch.distributed.new_group(ranks=ranks_list)
node_group = create_group(ranks=ranks_list, group_desc=f'QA_EVALUATOR_NODE_GROUP')
if node_id == node:
device_start_rank = start_rank
......
File mode changed from 100644 to 100755