Commit 0d99ae1f authored by silencealiang

add

parent c271aaae
Pipeline #2498 canceled
......@@ -201,7 +201,6 @@ def validate_args(args, defaults={}):
assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder."
if args.encoder_tensor_model_parallel_size > 0:
assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder."
......@@ -401,6 +400,14 @@ def validate_args(args, defaults={}):
assert not args.use_dist_ckpt, \
'--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet'
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
args.main_grads_dtype = dtype_map[args.main_grads_dtype]
args.main_params_dtype = dtype_map[args.main_params_dtype]
args.exp_avg_dtype = dtype_map[args.exp_avg_dtype]
args.exp_avg_sq_dtype = dtype_map[args.exp_avg_sq_dtype]
if args.fp8_param_gather:
assert args.use_distributed_optimizer, \
'--fp8-param-gather only supported with distributed optimizer'
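The dtype_map resolution above turns the new string-valued precision flags into torch dtypes. A standalone sketch of the same pattern, using a throwaway parser rather than Megatron's real one (only the two flags shown here are exercised):

import argparse
import torch

# Same mapping as in validate_args; fp8 states are stored in uint8 tensors.
dtype_map = {'fp32': torch.float32, 'bf16': torch.bfloat16,
             'fp16': torch.float16, 'fp8': torch.uint8}

parser = argparse.ArgumentParser()
parser.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'])
parser.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'])
args = parser.parse_args(['--main-grads-dtype', 'bf16', '--exp-avg-dtype', 'fp8'])

args.main_grads_dtype = dtype_map[args.main_grads_dtype]  # torch.bfloat16
args.exp_avg_dtype = dtype_map[args.exp_avg_dtype]        # torch.uint8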
......@@ -422,7 +429,11 @@ def validate_args(args, defaults={}):
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
if args.accumulate_allreduce_grads_in_fp32:
assert args.main_grads_dtype == torch.float32, \
"--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
if not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
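A condensed restatement of the precedence the rewritten bf16 branch implements (illustrative only, not the full validate_args flow):

import torch

def resolve_grad_accum(accumulate_allreduce_grads_in_fp32, main_grads_dtype):
    # An explicit --accumulate-allreduce-grads-in-fp32 requires fp32 main grads.
    if accumulate_allreduce_grads_in_fp32:
        assert main_grads_dtype == torch.float32, \
            "--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
    # bf16 params still auto-enable fp32 accumulation, but only when main grads are fp32,
    # leaving room for bf16 main grads under the precision-aware optimizer.
    if not accumulate_allreduce_grads_in_fp32 and main_grads_dtype == torch.float32:
        accumulate_allreduce_grads_in_fp32 = True
    return accumulate_allreduce_grads_in_fp32

assert resolve_grad_accum(False, torch.float32) is True
assert resolve_grad_accum(False, torch.bfloat16) is False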
......@@ -643,7 +654,7 @@ def validate_args(args, defaults={}):
'--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
# FlashAttention
args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton or args.use_flash_attn_torch
# Legacy RoPE arguments
if args.use_rotary_position_embeddings:
......@@ -1366,6 +1377,8 @@ def _add_training_args(parser):
group.add_argument('--use-flash-attn-cutlass', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--use-flash-attn-torch', action='store_true',
help='use FlashAttention implementation of attention using torch.')
group.add_argument('--use-flash-attn-triton', action='store_true',
help='use FlashAttention implementation of attention using Triton.')
group.add_argument('--disable-bias-linear', action='store_false',
......@@ -2078,6 +2091,7 @@ def _add_vision_args(parser):
def _add_moe_args(parser):
group = parser.add_argument_group(title="moe")
# General arguments
group.add_argument('--expert-model-parallel-size', type=int, default=1,
help='Degree of expert model parallelism.')
group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
......@@ -2103,16 +2117,23 @@ def _add_moe_args(parser):
help='Enable overlapping between shared expert computations and dispatcher communications. '
'Without this, the shared experts execute after the routed experts. '
'Only effective when moe-shared-expert-intermediate-size is set.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
# Router arguments
group.add_argument('--moe-router-load-balancing-type', type=str,
choices=['aux_loss', 'sinkhorn', 'none'],
choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
default='aux_loss',
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
group.add_argument('--moe-router-topk-limited-devices', type=int, default=None,
help='Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. Default is None, which means no limited devices.')
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
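The new "seq_aux_loss" choice computes the GShard-style balancing statistic per sample instead of over the whole batch. A rough, self-contained sketch of the difference (illustrative shapes and scaling only, not the Megatron-Core router code):

import torch

def aux_loss(probs, topk_mask, num_experts):
    # probs: [tokens, experts] softmax router scores; topk_mask: 1.0 where selected.
    f = topk_mask.float().mean(dim=0)   # fraction of assignments per expert
    p = probs.mean(dim=0)               # mean router probability per expert
    return num_experts * torch.sum(f * p)

def seq_aux_loss(probs, topk_mask, num_experts, seq_len):
    # Same statistic, but computed independently for each sequence and then averaged.
    probs = probs.view(-1, seq_len, num_experts)
    topk_mask = topk_mask.view(-1, seq_len, num_experts)
    f = topk_mask.float().mean(dim=1)   # [batch, experts]
    p = probs.mean(dim=1)               # [batch, experts]
    return num_experts * torch.sum(f * p, dim=-1).mean()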
......@@ -2185,4 +2206,18 @@ def _add_experimental_args(parser):
'the overridden pattern')
group.add_argument('--yaml-cfg', type=str, default=None,
help = 'Config file to add additional arguments')
# Args of precision-aware optimizer
group.add_argument('--use-precision-aware-optimizer', action='store_true',
help='Use the precision-aware optimizer in TransformerEngine, which allows '
'setting the main params and optimizer states to lower precision, such as '
'fp16 and fp8.')
group.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'],
help='Dtype of main grads when enabling precision-aware-optimizer')
group.add_argument('--main-params-dtype', default='fp32', choices=['fp32', 'fp16'],
help='Dtype of main params when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-sq-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg_sq when enabling precision-aware-optimizer')
return parser
File mode changed from 100755 to 100644
......@@ -361,6 +361,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Collect rng state across data parallel ranks.
rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)
# Collect rerun state across all ranks
rerun_state_machine = get_rerun_state_machine()
rerun_state = rerun_state_machine.state_dict(
data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
)
# Checkpoint name.
return_base_dir = (ckpt_type != CheckpointType.LEGACY)
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
......@@ -379,7 +385,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(checkpoint_name)
ensure_directory_exists(optim_checkpoint_name)
optimizer.save_parameter_state(optim_checkpoint_name)
if not optimizer.is_stub_optimizer:
optimizer.save_parameter_state(optim_checkpoint_name)
async_save_request = None
if args.async_save:
......@@ -409,7 +416,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
use_dist_ckpt=ckpt_type != CheckpointType.LEGACY,
iteration=iteration,
optim_sd_kwargs=optim_sd_kwargs,
train_data_iterator=train_data_iterator,
rerun_state=rerun_state,
)
if args.enable_ft_package and ft_client is not None:
......@@ -593,7 +600,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None, train_data_iterator=None):
optim_sd_kwargs=None, rerun_state=None):
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
......@@ -614,7 +621,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
model[i].state_dict_for_save_checkpoint())
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
if optimizer is not None and not optimizer.is_stub_optimizer:
state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
if use_dist_ckpt else
optimizer.state_dict())
......@@ -623,10 +630,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
opt_param_scheduler.state_dict()
# Rerun state
rerun_state_machine = get_rerun_state_machine()
state_dict['rerun_state_machine'] = rerun_state_machine.get_checkpoint_state(
train_data_iterator
)
state_dict['rerun_state_machine'] = rerun_state
# RNG states.
if not args.no_save_rng:
......@@ -1136,6 +1140,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
gen_sd_optim = None
gen_sd_opt_param_scheduler = None
# Determine if rerun state will be loaded
if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune):
rerun_state_machine = get_rerun_state_machine()
gen_sd_rerun_state = rerun_state_machine.state_dict(
data_iterator=None, use_dist_ckpt=True
)
else:
gen_sd_rerun_state = None
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))
# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present.
......@@ -1145,9 +1160,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
stack.enter_context(m.hide_loss_modules())
load_kwargs['sharded_state_dict'] = generate_state_dict(
args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, train_data_iterator=None
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
)
# When "--fp8-param-gather" is disabled, this function doesn't modify anything.
fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
......@@ -1230,7 +1245,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if not release and not args.finetune and not args.no_load_optim:
try:
# Load state dict.
if not skip_load_to_model_and_opt and optimizer is not None:
if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
optimizer.load_state_dict(state_dict['optimizer'])
# Load distributed optimizer's custom parameter state.
......@@ -1268,7 +1283,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rerun state
try:
if 'rerun_state_machine' in state_dict:
get_rerun_state_machine().set_checkpoint_state(state_dict['rerun_state_machine'])
get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}")
sys.exit()
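The rerun-state API switches from get_checkpoint_state/set_checkpoint_state to a state_dict/load_state_dict pair that is captured once in save_checkpoint and passed into generate_state_dict. A toy stand-in showing the new contract (the real RerunStateMachine lives in megatron.core and stores different fields):

class ToyRerunStateMachine:
    def __init__(self):
        self.mode = "disabled"
        self.iterator_state = None

    def state_dict(self, data_iterator=None, use_dist_ckpt=False):
        return {"mode": self.mode,
                "iterator_state": getattr(data_iterator, "state", None),
                "dist_ckpt": use_dist_ckpt}

    def load_state_dict(self, state):
        self.mode = state["mode"]
        self.iterator_state = state["iterator_state"]

machine = ToyRerunStateMachine()
checkpoint = {"rerun_state_machine": machine.state_dict(data_iterator=None, use_dist_ckpt=True)}
ToyRerunStateMachine().load_state_dict(checkpoint["rerun_state_machine"])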
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -69,14 +69,16 @@ from megatron.core.num_microbatches_calculator import (
from .async_utils import maybe_finalize_async_save
from .utils import (
append_to_progress_log,
calc_params_l2_norm,
check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model,
append_to_progress_log,
update_use_dist_ckpt,
)
from .global_vars import (
......@@ -86,7 +88,8 @@ from .global_vars import (
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
get_one_logger,
)
from . import one_logger_utils
from . import ft_integration
......@@ -135,13 +138,6 @@ def num_floating_point_operations(args, batch_size):
# - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
expansion_factor = 3 * 2 * 2
# print(f"batch_size: {batch_size}, \
# query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
# num_experts_routed_to: {num_experts_routed_to}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# ")
return (
expansion_factor
* batch_size
......@@ -219,7 +215,7 @@ def get_start_time_from_progress_log():
def preprocess_common_state_dict(common_state_dict):
import copy
# Convert args key of type namespace to dictionary
# Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if it exists, since they are expected to be different
......@@ -753,7 +749,7 @@ def setup_model_and_optimizer(model_provider_func,
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -790,10 +786,20 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
......@@ -832,7 +838,6 @@ def train_step(forward_step_func, data_iterator,
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
......@@ -913,6 +918,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# learning rate will be None on ranks without trainable params, so we must gather across mp ranks
learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
# Tensorboard values.
# Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \
......@@ -930,12 +937,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
if args.skipped_train_samples > 0:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
if wandb_writer:
......@@ -1035,7 +1042,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
# Decoupled_learning_rate should be not None only on first and last pipeline stage.
log_string += f' learning rate: {learning_rate:.6E} |'
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
......@@ -1068,7 +1074,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
if report_memory_flag:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
......@@ -1120,10 +1126,10 @@ def enable_forward_pre_hook(model_chunks):
model_chunk.enable_forward_pre_hook()
def disable_forward_pre_hook(model_chunks):
def disable_forward_pre_hook(model_chunks, param_sync=True):
for model_chunk in model_chunks:
assert isinstance(model_chunk, DDP)
model_chunk.disable_forward_pre_hook()
model_chunk.disable_forward_pre_hook(param_sync=param_sync)
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
......@@ -1223,7 +1229,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
prof.stop()
else:
torch.cuda.cudart().cudaProfilerStop()
# Manual garbage collection.
if args.manual_gc:
......@@ -1361,6 +1366,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
pre_hook_enabled = False
should_exit = False
exit_code = 0
......@@ -1414,26 +1420,52 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
if args.rank in [0]:
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
print(p.key_averages(group_by_input_shape=True,
group_by_stack_n=5).table(sort_by="self_cuda_time_total",
row_limit=-1,
max_src_column_width=100,
max_name_column_width=280,
max_shapes_column_width=200))
p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=trace_handler)
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
record_shapes=True,
#on_trace_ready=torch.profiler.tensorboard_trace_handler('./torch_prof_data'))
on_trace_ready=trace_handler)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")
start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
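The profiler schedule above maps the step window onto wait/warmup/active phases. With hypothetical values --profile-step-start 10 and --profile-step-end 12, tracing covers steps 10 and 11:

import torch

profile_step_start, profile_step_end = 10, 12        # hypothetical values
schedule = torch.profiler.schedule(
    wait=max(profile_step_start - 1, 0),              # 9 idle steps
    warmup=1 if profile_step_start > 0 else 0,        # 1 warmup step
    active=profile_step_end - profile_step_start,     # 2 traced steps
    repeat=1,
)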
......@@ -1456,12 +1488,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
......@@ -1469,23 +1501,41 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
args.curr_iteration = iteration
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
# why is skipped_iter ignored?
# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
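The consumed-samples update is a simple product of data-parallel size, micro-batch size, and microbatch count; with hypothetical values 8, 2, and 4, each iteration advances consumed_train_samples by 64:

data_parallel_world_size = 8   # hypothetical
micro_batch_size = 2           # hypothetical
num_microbatches = 4           # hypothetical
batch_size = data_parallel_world_size * micro_batch_size * num_microbatches
assert batch_size == 64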
......@@ -1499,8 +1549,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
# Logging.
loss_scale = optimizer.get_loss_scale().item()
if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
......@@ -1511,11 +1565,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
......@@ -1523,16 +1577,17 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -1543,6 +1598,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
......@@ -1552,12 +1608,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
......@@ -1569,7 +1625,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
writer.flush()
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if args.use_distributed_optimizer and args.overlap_param_gather:
if pre_hook_enabled:
disable_forward_pre_hook(model)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
......@@ -1701,7 +1757,9 @@ def evaluate(forward_step_func,
timers('evaluate').stop()
timers.log(['evaluate'])
rerun_state_machine.set_mode(rerun_mode)
rerun_state_machine.set_mode(rerun_mode)
return total_loss_dict, collected_non_loss_data, False
......@@ -1869,12 +1927,15 @@ def build_train_valid_test_data_iterators(
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return RerunDataIterator(dataloader)
return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
return RerunDataIterator(cyclic_iter(dataloader))
return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
return RerunDataIterator(dataloader, make_iterable=False)
if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")
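In the new "external" branch the user-supplied dataloader may be a list (presumably one per virtual-pipeline model chunk), and each element now gets its own RerunDataIterator wrapper. A toy illustration of just that dispatch, with a trivial wrapper standing in for RerunDataIterator:

class ToyRerunDataIterator:
    def __init__(self, iterable):
        self._it = iter(iterable)

    def __next__(self):
        return next(self._it)

def wrap_external(dataloader):
    # Mirrors the isinstance(dataloader, list) dispatch above.
    if isinstance(dataloader, list):
        return [ToyRerunDataIterator(d) for d in dataloader]
    return ToyRerunDataIterator(dataloader)

wrapped = wrap_external([range(3), range(3)])
assert isinstance(wrapped, list) and len(wrapped) == 2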
......
......@@ -36,7 +36,11 @@ from megatron.core import DistributedDataParallel as DDP
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
from megatron.core.utils import (
get_batch_on_this_cp_rank,
get_data_parallel_group_if_dtensor,
to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared
......@@ -90,13 +94,16 @@ def calc_params_l2_norm(model):
# Calculate dense param norm
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
if len(params_data) > 0:
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
else:
norm_2 = torch.tensor([0.0], dtype=torch.float32, device='cuda')
if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2,
......@@ -140,6 +147,41 @@ def average_losses_across_data_parallel_group(losses):
return averaged_losses
def reduce_max_stat_across_model_parallel_group(stat: float) -> float:
"""
Ranks without an optimizer have no grad_norm or num_zeros_in_grad stats,
but the logging and writer rank still needs those values.
This function reduces a stat across the model parallel group with an
all-reduce max, which is safe because the values have already been summed
across optimizer ranks where possible.
"""
if stat is None:
stat = -1.0
stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.all_reduce(
stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()
)
if stat.item() == -1.0:
return None
else:
return stat.item()
def logical_and_across_model_parallel_group(input: bool) -> bool:
"""
This function computes a logical AND of a bool value across the model parallel group (implemented as a MIN all-reduce).
"""
if input is True:
input = 1
else:
input = 0
input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(
input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()
)
return bool(input.item())
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
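A single-process illustration of the two reductions defined above: grad_norm and num_zeros_in_grad are non-negative, so ranks without an optimizer contribute a -1.0 sentinel that only wins the MAX if every rank lacks the stat (in which case None is returned), and the boolean AND is realized as a MIN over {0, 1}. Simulating the all-reduces with plain max/min:

def reduce_max_stat_simulated(per_rank_stats):
    # None becomes the -1.0 sentinel before the (simulated) MAX all-reduce.
    reduced = max(-1.0 if s is None else s for s in per_rank_stats)
    return None if reduced == -1.0 else reduced

def logical_and_simulated(per_rank_flags):
    # True/False -> 1/0, then a (simulated) MIN all-reduce acts as logical AND.
    return bool(min(int(f) for f in per_rank_flags))

assert reduce_max_stat_simulated([None, 2.5, None]) == 2.5
assert reduce_max_stat_simulated([None, None]) is None
assert logical_and_simulated([True, False, True]) is False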
......@@ -254,39 +296,6 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def get_batch_on_this_cp_rank(batch):
""" Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
# With causal masking, each token only attends to its prior tokens. Simply split
# sequence into CP chunks can result in severe load imbalance. That's to say, chunks
# at the end of sequence have bigger workload than others. To address this issue,
# we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
# and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
# that we can get balanced workload among GPUs in a context parallel group.
args = get_args()
cp_size = args.context_parallel_size
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
device="cpu", pin_memory=True).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
return batch
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
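The deleted get_batch_on_this_cp_rank now comes from megatron.core.utils (see the import change above). Its 2*CP chunking balances causal-attention work by giving rank r chunks r and 2*CP-1-r. A self-contained toy with cp_size=2 and an 8-token sequence:

import torch

def cp_slice(val, cp_rank, cp_size, seq_dim=1):
    # Reshape the sequence into 2*cp_size chunks and pick chunks cp_rank and
    # (2*cp_size - cp_rank - 1), mirroring the removed implementation.
    val = val.view(*val.shape[:seq_dim], 2 * cp_size, -1, *val.shape[seq_dim + 1:])
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2:])

tokens = torch.arange(8).unsqueeze(0)                 # [1, 8]: one toy sequence
print(cp_slice(tokens, cp_rank=0, cp_size=2))         # tensor([[0, 1, 6, 7]])
print(cp_slice(tokens, cp_rank=1, cp_size=2))         # tensor([[2, 3, 4, 5]])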
......@@ -431,11 +440,11 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
......@@ -443,11 +452,11 @@ def get_batch_on_this_tp_rank(data_iterator):
elif mpu.is_pipeline_last_stage():
tokens=None
position_ids=None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
batch = {
'tokens': tokens,
'labels': labels,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644