Commit 99a0c39e authored by xingjinliang

Sync latest code

parent 50fe58fa
Pipeline #2152 passed
File mode changed from 100755 to 100644
......@@ -361,6 +361,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Collect rng state across data parallel ranks.
rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)
# Collect rerun state across all ranks
rerun_state_machine = get_rerun_state_machine()
rerun_state = rerun_state_machine.state_dict(
data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
)
# Checkpoint name.
return_base_dir = (ckpt_type != CheckpointType.LEGACY)
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
......@@ -379,7 +385,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(checkpoint_name)
ensure_directory_exists(optim_checkpoint_name)
optimizer.save_parameter_state(optim_checkpoint_name)
if not optimizer.is_stub_optimizer:
optimizer.save_parameter_state(optim_checkpoint_name)
async_save_request = None
if args.async_save:
......@@ -409,7 +416,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
use_dist_ckpt=ckpt_type != CheckpointType.LEGACY,
iteration=iteration,
optim_sd_kwargs=optim_sd_kwargs,
train_data_iterator=train_data_iterator,
rerun_state=rerun_state,
)
if args.enable_ft_package and ft_client is not None:
......@@ -593,7 +600,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None, train_data_iterator=None):
optim_sd_kwargs=None, rerun_state=None):
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
......@@ -614,7 +621,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
model[i].state_dict_for_save_checkpoint())
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
if optimizer is not None and not optimizer.is_stub_optimizer:
state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
if use_dist_ckpt else
optimizer.state_dict())
......@@ -623,10 +630,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
opt_param_scheduler.state_dict()
# Rerun state
rerun_state_machine = get_rerun_state_machine()
state_dict['rerun_state_machine'] = rerun_state_machine.get_checkpoint_state(
train_data_iterator
)
state_dict['rerun_state_machine'] = rerun_state
# RNG states.
if not args.no_save_rng:
......@@ -1136,6 +1140,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
gen_sd_optim = None
gen_sd_opt_param_scheduler = None
# Determine if rerun state will be loaded
if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune):
rerun_state_machine = get_rerun_state_machine()
gen_sd_rerun_state = rerun_state_machine.state_dict(
data_iterator=None, use_dist_ckpt=True
)
else:
gen_sd_rerun_state = None
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))
# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present.
......@@ -1145,9 +1160,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
stack.enter_context(m.hide_loss_modules())
load_kwargs['sharded_state_dict'] = generate_state_dict(
args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, train_data_iterator=None
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
)
# When "--fp8-param-gather" is disabled, this function doesn't modify anything.
fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
......@@ -1230,7 +1245,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if not release and not args.finetune and not args.no_load_optim:
try:
# Load state dict.
if not skip_load_to_model_and_opt and optimizer is not None:
if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
optimizer.load_state_dict(state_dict['optimizer'])
# Load distributed optimizer's custom parameter state.
......@@ -1268,7 +1283,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rerun state
try:
if 'rerun_state_machine' in state_dict:
get_rerun_state_machine().set_checkpoint_state(state_dict['rerun_state_machine'])
get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}")
sys.exit()
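Side note on this hunk: the rename from get_checkpoint_state/set_checkpoint_state to state_dict/load_state_dict follows the usual PyTorch save/restore naming convention. Below is a minimal stand-in sketch of that convention; TinyStateMachine and its fields are invented for illustration and say nothing about the real RerunStateMachine internals.

class TinyStateMachine:
    def __init__(self):
        self.mode = "disabled"
        self.seen_iterations = 0

    def state_dict(self):
        # Serialize everything needed to resume after a restart.
        return {"mode": self.mode, "seen_iterations": self.seen_iterations}

    def load_state_dict(self, sd):
        # Restore exactly what state_dict() produced.
        self.mode = sd["mode"]
        self.seen_iterations = sd["seen_iterations"]

src, dst = TinyStateMachine(), TinyStateMachine()
src.seen_iterations = 42
dst.load_state_dict(src.state_dict())
assert dst.seen_iterations == 42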
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -69,14 +69,16 @@ from megatron.core.num_microbatches_calculator import (
from .async_utils import maybe_finalize_async_save
from .utils import (
append_to_progress_log,
calc_params_l2_norm,
check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model,
append_to_progress_log,
update_use_dist_ckpt,
)
from .global_vars import (
......@@ -86,7 +88,8 @@ from .global_vars import (
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
get_one_logger,
)
from . import one_logger_utils
from . import ft_integration
......@@ -135,13 +138,6 @@ def num_floating_point_operations(args, batch_size):
# - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
expansion_factor = 3 * 2 * 2
# print(f"batch_size: {batch_size}, \
# query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
# num_experts_routed_to: {num_experts_routed_to}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# ")
return (
expansion_factor
* batch_size
......@@ -219,7 +215,7 @@ def get_start_time_from_progress_log():
def preprocess_common_state_dict(common_state_dict):
import copy
# Convert args key of type namespace to dictionary
# Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if it exists, since they are expected to be different
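A self-contained sketch of what this preprocessing achieves: args is turned into a plain dict and the per-rank fields named in the comment above are dropped, so the remaining "common" values can be compared across ranks. The lr field and the exact pop calls are illustrative only.

from argparse import Namespace
import copy

def preprocess(common_state_dict):
    out = copy.deepcopy(common_state_dict)
    out['args'] = vars(out['args'])          # Namespace -> plain dict
    for key in ('rank', 'local_rank'):       # expected to differ per rank
        out['args'].pop(key, None)
    return out

sd_rank0 = {'args': Namespace(lr=1e-4, rank=0, local_rank=0)}
sd_rank1 = {'args': Namespace(lr=1e-4, rank=1, local_rank=1)}
assert preprocess(sd_rank0) == preprocess(sd_rank1)   # per-rank fields no longer cause mismatches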
......@@ -753,7 +749,7 @@ def setup_model_and_optimizer(model_provider_func,
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -790,10 +786,20 @@ def train_step(forward_step_func, data_iterator,
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
......@@ -832,7 +838,6 @@ def train_step(forward_step_func, data_iterator,
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
......@@ -913,6 +918,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# learning rate will be None on ranks without trainable params, so we must gather across mp ranks
learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
# Tensorboard values.
# Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \
......@@ -930,12 +937,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
if args.skipped_train_samples > 0:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
if wandb_writer:
......@@ -1035,7 +1042,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
# Decoupled_learning_rate should be not None only on first and last pipeline stage.
log_string += f' learning rate: {learning_rate:.6E} |'
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
......@@ -1068,7 +1074,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
if report_memory_flag:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
......@@ -1120,10 +1126,10 @@ def enable_forward_pre_hook(model_chunks):
model_chunk.enable_forward_pre_hook()
def disable_forward_pre_hook(model_chunks):
def disable_forward_pre_hook(model_chunks, param_sync=True):
for model_chunk in model_chunks:
assert isinstance(model_chunk, DDP)
model_chunk.disable_forward_pre_hook()
model_chunk.disable_forward_pre_hook(param_sync=param_sync)
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
......@@ -1223,7 +1229,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
prof.stop()
else:
torch.cuda.cudart().cudaProfilerStop()
# Manual garbage collection.
if args.manual_gc:
......@@ -1361,6 +1366,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
pre_hook_enabled = False
should_exit = False
exit_code = 0
......@@ -1434,6 +1440,24 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
import ctypes
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")
start_iteration = iteration
# Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
......@@ -1456,12 +1480,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
......@@ -1469,23 +1493,41 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
args.curr_iteration = iteration
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
# why is skipped_iter ignored?
# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
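As a side note on the fp16 comment earlier in this hunk (skipping NaN iterations until the loss scale reaches steady state), here is a minimal dynamic loss-scaling sketch; the initial scale, growth interval, and halving rule are common defaults, not Megatron's exact implementation.

scale, growth_interval, good_steps = 2.0 ** 16, 2000, 0

def loss_scale_step(found_overflow):
    # Returns True if the optimizer actually stepped, False if the iteration
    # was skipped because gradients overflowed at the current scale.
    global scale, good_steps
    if found_overflow:
        scale, good_steps = scale / 2, 0
        return False
    good_steps += 1
    if good_steps % growth_interval == 0:
        scale *= 2
    return True

# Early fp16 iterations typically overflow until the scale drops low enough,
# which is exactly the window in which the forward pre-hook stays disabled.
for overflow in (True, True, True, False, False):
    print(loss_scale_step(overflow), scale)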
......@@ -1499,8 +1541,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
# Logging.
loss_scale = optimizer.get_loss_scale().item()
if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
......@@ -1511,11 +1557,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
......@@ -1523,16 +1569,17 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -1543,6 +1590,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
......@@ -1552,12 +1600,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
......@@ -1569,7 +1617,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
writer.flush()
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if args.use_distributed_optimizer and args.overlap_param_gather:
if pre_hook_enabled:
disable_forward_pre_hook(model)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
......@@ -1701,7 +1749,9 @@ def evaluate(forward_step_func,
timers('evaluate').stop()
timers.log(['evaluate'])
rerun_state_machine.set_mode(rerun_mode)
rerun_state_machine.set_mode(rerun_mode)
return total_loss_dict, collected_non_loss_data, False
......@@ -1869,12 +1919,15 @@ def build_train_valid_test_data_iterators(
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return RerunDataIterator(dataloader)
return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
return RerunDataIterator(cyclic_iter(dataloader))
return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
return RerunDataIterator(dataloader, make_iterable=False)
if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")
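One note on the "single" and "cyclic" branches above: wrapping the dataloader in iter(...) before handing it to RerunDataIterator matters because an iterable is not automatically an iterator; presumably the wrapper advances it with next(). A plain-Python illustration, with a list standing in for a DataLoader:

data = [1, 2, 3]        # iterable, but not an iterator
it = iter(data)         # what the "single" branch now passes along
assert next(it) == 1    # next() works on the iterator ...
try:
    next(data)          # ... but raises TypeError on the bare iterable
except TypeError as exc:
    print("bare iterable:", exc)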
......
......@@ -36,7 +36,11 @@ from megatron.core import DistributedDataParallel as DDP
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
from megatron.core.utils import (
get_batch_on_this_cp_rank,
get_data_parallel_group_if_dtensor,
to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared
......@@ -90,13 +94,16 @@ def calc_params_l2_norm(model):
# Calculate dense param norm
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
if len(params_data) > 0:
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
else:
norm_2 = torch.tensor([0.0], dtype=torch.float32, device='cuda')
if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2,
......@@ -140,6 +147,41 @@ def average_losses_across_data_parallel_group(losses):
return averaged_losses
def reduce_max_stat_across_model_parallel_group(stat: float) -> float:
"""
Ranks without an optimizer will have no grad_norm or num_zeros_in_grad stats.
We need to ensure the logging and writer rank has those values.
This function reduces a stat tensor across the model parallel group.
We use an all_reduce max since the values have already been summed across optimizer ranks where possible
"""
if stat is None:
stat = -1.0
stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.all_reduce(
stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()
)
if stat.item() == -1.0:
return None
else:
return stat.item()
def logical_and_across_model_parallel_group(input: bool) -> bool:
"""
This function gathers a bool value across the model parallel group
"""
if input is True:
input = 1
else:
input = 0
input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(
input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()
)
return bool(input.item())
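A standalone, single-process sketch of the two reduction patterns above, runnable on CPU with a one-rank gloo group; it mirrors the sentinel/MAX and MIN tricks but is not the Megatron functions themselves (those operate on the model-parallel group with CUDA tensors).

import os
import torch
import torch.distributed as dist

def reduce_max_stat(stat):
    # None becomes the sentinel -1.0, which can never win a MAX reduction
    # against a real non-negative statistic such as a grad norm.
    value = torch.tensor([-1.0 if stat is None else stat], dtype=torch.float32)
    dist.all_reduce(value, op=dist.ReduceOp.MAX)
    return None if value.item() == -1.0 else value.item()

def logical_and_across_ranks(flag):
    # MIN over {0, 1}: the result is 1 only if every rank contributed True.
    value = torch.tensor([1 if flag else 0], dtype=torch.int)
    dist.all_reduce(value, op=dist.ReduceOp.MIN)
    return bool(value.item())

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(reduce_max_stat(None))           # -> None (no rank had the stat)
    print(reduce_max_stat(2.5))            # -> 2.5
    print(logical_and_across_ranks(True))  # -> True
    dist.destroy_process_group()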
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
......@@ -254,39 +296,6 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def get_batch_on_this_cp_rank(batch):
""" Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
# With causal masking, each token only attends to its prior tokens. Simply split
# sequence into CP chunks can result in severe load imbalance. That's to say, chunks
# at the end of sequence have bigger workload than others. To address this issue,
# we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
# and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
# that we can get balanced workload among GPUs in a context parallel group.
args = get_args()
cp_size = args.context_parallel_size
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
device="cpu", pin_memory=True).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
return batch
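The removed helper above (now imported from megatron.core.utils instead) describes a 2*CP chunking scheme in its docstring. The toy below reproduces just that index arithmetic on CPU for cp_size=2 and an 8-token sequence, with no Megatron state involved: rank r keeps chunks r and 2*CP-1-r, so a cheap early causal chunk is always paired with an expensive late one.

import torch

def split_for_cp_rank(val, cp_rank, cp_size, seq_dim=1):
    # Reshape [..., seq, ...] -> [..., 2*cp, seq/(2*cp), ...], pick the two
    # chunks assigned to this rank, then flatten back along the sequence dim.
    val = val.view(
        *val.shape[:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[seq_dim + 1:],
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2:])

tokens = torch.arange(8).unsqueeze(0)    # one sequence of 8 tokens
print(split_for_cp_rank(tokens, 0, 2))   # tensor([[0, 1, 6, 7]]) -> chunks 0 and 3
print(split_for_cp_rank(tokens, 1, 2))   # tensor([[2, 3, 4, 5]]) -> chunks 1 and 2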
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
......@@ -431,11 +440,11 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
......@@ -443,11 +452,11 @@ def get_batch_on_this_tp_rank(data_iterator):
elif mpu.is_pipeline_last_stage():
tokens=None
position_ids=None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
batch = {
'tokens': tokens,
'labels': labels,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -91,11 +91,11 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.fp8)
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention)
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
build_model_context = nullcontext
build_model_context_args = {}
......@@ -128,6 +128,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling
)
#model = torch.compile(model,mode='max-autotune-no-cudagraphs')
print_rank_0(model)
return model
......