Commit fa142de0 authored by dongcl

patch for megatron core 0.12.0

parent cf5d3189
Subproject commit 408eb7186a68ba4d30ee6cc8b05b4de6ba702148
Subproject commit d580efc68a9f0dbf1945f834f6f6200cd01d3343
......@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
staticmethod,
apply_wrapper=True)
# reduce_scatter_to_sequence_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
# reduce_from_tensor_model_parallel_region
MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
torch._dynamo.disable,
apply_wrapper=True)
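For context, a minimal standalone sketch (not part of this patch; the function name is a hypothetical stand-in) of what wrapping with torch._dynamo.disable achieves: the wrapped callable is skipped by TorchDynamo, so it falls back to eager execution even when called from a torch.compile'd region.

import torch

def _reduce_like(x):
    # hypothetical stand-in for one of the tensor_parallel.mappings functions
    return x * 2

# after wrapping, TorchDynamo skips this callable instead of tracing it
_reduce_like = torch._dynamo.disable(_reduce_like)

@torch.compile
def step(x):
    return _reduce_like(x) + 1  # the call runs eagerly, outside the compiled graph

print(step(torch.ones(2)))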
# flux
if int(os.getenv("USE_FLUX_OVERLAP", "0")):
from ..core.tensor_parallel.layers import (
......@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
from ..training.initialize import _initialize_distributed
from ..training.initialize import _compile_dependencies
from ..training.training import train
from ..training.initialize import _set_random_seed
MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
build_tokenizer)
......@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
_compile_dependencies)
# add fixed seed
MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
_set_random_seed)
# add trace_handler
MegatronAdaptation.register('megatron.training.training.train',
train)
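As a rough sketch of the registration idea (an assumption for illustration; MegatronAdaptation's internals are not shown in this diff): resolve the dotted path, then rebind the module attribute, either replacing the original outright or, with apply_wrapper, wrapping it.

import importlib

def register(dotted_path, replacement, apply_wrapper=False):
    # hypothetical helper, not the repo's actual implementation
    module_path, name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    original = getattr(module, name)
    setattr(module, name, replacement(original) if apply_wrapper else replacement)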
......
......@@ -7,6 +7,7 @@ from megatron.training import get_args
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.pipeline_parallel.schedules import set_current_microbatch
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import (
get_attr_wrapped_model,
......@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func
def set_current_microbatch(model, microbatch_id):
"""Set the current microbatch."""
decoder_exists = True
decoder = None
try:
decoder = get_attr_wrapped_model(model, "decoder")
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
for layer in decoder.layers:
layer.current_microbatch = microbatch_id
def get_pp_rank_microbatches(
num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
):
......
......@@ -16,35 +16,6 @@ from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerL
class TransformerLayer(MegatronCoreTransformerLayer):
def _callable_wrapper(
self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
):
"""
Wraps a function call so that it waits for a given CUDA event before
proceeding and then runs the function on a specified CUDA stream.
"""
torch.cuda.nvtx.range_push(func.__name__)
event.wait(stream)
with torch.cuda.stream(stream):
outputs = func(*args, **kwargs)
event.record(stream)
if skip_detach:
torch.cuda.nvtx.range_pop()
return outputs
detached_output_tensors = []
if not is_forward:
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
for tensor in outputs:
if tensor is None:
detached_output_tensors.append(None)
elif tensor.dtype.is_floating_point:
detached_output_tensors.append(tensor.detach().requires_grad_(True))
else:
detached_output_tensors.append(tensor.detach())
torch.cuda.nvtx.range_pop()
return outputs, detached_output_tensors
def forward(
self,
hidden_states: Tensor,
......@@ -123,7 +94,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
self.input_layernorm, hidden_states
)
else:
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
......@@ -138,6 +115,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self.input_layernorm_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
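The recompute path above trades memory for an extra forward pass: the input layernorm output is dropped once attention has consumed it and re-materialized from a gradient hook during backward. For reference only (this is not the repo's CheckpointWithoutOutput), the stock torch.utils.checkpoint implements the same recompute-in-backward idea:

import torch
from torch.utils.checkpoint import checkpoint

norm = torch.nn.LayerNorm(16)
x = torch.randn(4, 16, requires_grad=True)
# activations inside `norm` are not kept; they are recomputed during backward
y = checkpoint(norm, x, use_reentrant=False)
y.sum().backward()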
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
......@@ -178,7 +162,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
)
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
self.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
......@@ -249,6 +239,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
if shared_expert_output is not None:
output += shared_expert_output
mlp_output_with_bias = (output, mlp_bias)
if self.recompute_pre_mlp_layernorm:
# discard the output of the pre-mlp layernorm and register the recompute
# as a gradient hook of mlp_output_with_bias[0]
self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
mlp_output_with_bias[0]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
......
......@@ -105,7 +105,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process
init_process_group_kwargs = {
'backend' : args.distributed_backend,
'backend': args.distributed_backend,
'world_size': args.world_size,
'rank': args.rank,
'init_method': args.dist_url,
......@@ -149,3 +149,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _set_random_seed(
seed_: int,
data_parallel_random_init: bool = False,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
):
"""Set random seed for reproducability."""
args = get_args()
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
# Ensure different data parallel ranks get different seeds
if data_parallel_random_init:
seed = seed + (10 * mpu.get_data_parallel_rank())
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(
seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
)
if args.reproduce:
assert args.attention_dropout == 0, f"To use the reproducibility feature, args.attention_dropout = {args.attention_dropout} must be set to 0."
assert args.hidden_dropout == 0, f"To use the reproducibility feature, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
torch.backends.cudnn.deterministic = True  # make the cuDNN backend use deterministic algorithms
torch.backends.cudnn.benchmark = False  # keep the convolution algorithm selection fixed
torch.use_deterministic_algorithms(True)  # use deterministic torch kernels to avoid nondeterminism
else:
raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
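Usage note (an assumption about the runtime environment, not something this patch sets): with torch.use_deterministic_algorithms(True) on CUDA, cuBLAS also needs a workspace configuration exported before the first kernel launch, otherwise affected ops raise a RuntimeError.

import os

# one of the two values accepted by cuBLAS for deterministic behavior (CUDA >= 10.2)
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")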
......@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
_Llama2Tokenizer,
CustomTikTokenizer,
_NullTokenizer,
_NullMultimodalTokenizer,
_vocab_size_with_padding
)
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
def build_tokenizer(args, **kwargs):
......@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
args.tokenizer_prompt_format,
args.special_tokens,
args.image_tag_type,
args.force_system_message,
)
elif args.tokenizer_type == 'NullMultimodalTokenizer':
assert args.vocab_size is not None
tokenizer = _NullMultimodalTokenizer(args.vocab_size)
elif args.tokenizer_type == "DeepSeekV2Tokenizer":
tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
args.padded_vocab_size = tokenizer.vocab_size
......
......@@ -53,18 +53,9 @@ from megatron.training.training import (
stimer = StragglerDetector()
def train(
forward_step_func,
model,
optimizer,
opt_param_scheduler,
train_data_iterator,
valid_data_iterator,
process_non_loss_data_func,
config,
checkpointing_context,
non_loss_data_func,
):
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
"""Training function: run train_step desired number of times, run validation, checkpoint."""
args = get_args()
timers = get_timers()
......@@ -74,10 +65,7 @@ def train(
try:
from workload_inspector.utils.webserver import run_server
import threading
threading.Thread(
target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
).start()
threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
except ModuleNotFoundError:
print_rank_0("workload inspector module not found.")
......@@ -100,17 +88,11 @@ def train(
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(
iteration=iteration,
consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples,
seq_length=args.seq_length,
train_iters=args.train_iters,
save=args.save,
async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
)
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
train_samples=args.train_samples, seq_length=args.seq_length,
train_iters=args.train_iters, save=args.save, async_save=args.async_save,
log_throughput=args.log_throughput,
num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)
num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
......@@ -118,10 +100,9 @@ def train(
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, (
'When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce'
)
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
if len(model) == 1:
config.no_sync_func = config.no_sync_func[0]
......@@ -145,9 +126,8 @@ def train(
if args.manual_gc:
# Disable the default garbage collector and perform the collection manually.
# This is to align the timing of garbage collection across ranks.
assert (
args.manual_gc_interval >= 0
), 'Manual garbage collection interval should be larger than or equal to 0'
assert args.manual_gc_interval >= 0, \
'Manual garbage collection interval should be larger than or equal to 0'
gc.disable()
gc.collect()
......@@ -157,13 +137,10 @@ def train(
world = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
mmcnt = args.straggler_minmax_count
stimer.configure(
world,
rank,
mmcnt=mmcnt,
enabled=not args.disable_straggler_on_startup,
port=args.straggler_ctrlr_port,
)
stimer.configure(world, rank,
mmcnt = mmcnt,
enabled = not args.disable_straggler_on_startup,
port = args.straggler_ctrlr_port)
num_floating_point_operations_since_last_log_event = 0.0
num_microbatches = get_num_microbatches()
......@@ -171,10 +148,10 @@ def train(
eval_iterations = 0
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics."""
num_floating_point_operations_since_current_train_start = (
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
)
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
......@@ -184,7 +161,7 @@ def train(
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
'seq_length': args.seq_length,
'seq_length': args.seq_length
}
# Cache into one-logger for callback.
if one_logger:
......@@ -192,11 +169,7 @@ def train(
one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)
prof = None
if (
args.profile
and torch.distributed.get_rank() in args.profile_ranks
and args.use_pytorch_profiler
):
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
......@@ -242,9 +215,8 @@ def train(
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(
model, cross_check=True
), "Parameter hashes not matching across DP replicas"
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
......@@ -270,20 +242,14 @@ def train(
# to make sure training configuration is still valid.
update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, (
f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}"
)
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
......@@ -292,9 +258,9 @@ def train(
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
......@@ -302,28 +268,19 @@ def train(
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
(
loss_dict,
skipped_iter,
should_checkpoint,
should_exit,
exit_code,
grad_norm,
num_zeros_in_grad,
) = train_step(
forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
)
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(
iteration,
model,
optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator=train_data_iterator,
)
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
......@@ -346,13 +303,12 @@ def train(
pre_hook_enabled = True
iteration += 1
batch_size = (
mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
)
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (
get_current_global_batch_size() - get_current_running_global_batch_size()
)
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
if args.decrease_batch_size_if_needed:
assert num_skipped_samples_in_batch >= 0
else:
......@@ -378,22 +334,16 @@ def train(
decoupled_learning_rate = param_group['lr']
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(
loss_dict,
total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration,
loss_scale,
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad,
)
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
......@@ -403,18 +353,11 @@ def train(
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(
prefix,
forward_step_func,
valid_data_iterator,
model,
iteration,
process_non_loss_data_func,
config,
verbose=False,
write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func,
)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -430,25 +373,13 @@ def train(
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(
model,
optimizer,
opt_param_scheduler,
iteration,
prof,
num_floating_point_operations_since_last_log_event,
)
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(
model,
optimizer,
opt_param_scheduler,
iteration,
num_floating_point_operations_so_far,
checkpointing_context,
train_data_iterator,
)
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
......@@ -477,7 +408,6 @@ def train(
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
one_logger_utils.finish()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far