Commit fa142de0 authored by dongcl

patch for megatron core 0.12.0

parent cf5d3189
-Subproject commit 408eb7186a68ba4d30ee6cc8b05b4de6ba702148
+Subproject commit d580efc68a9f0dbf1945f834f6f6200cd01d3343
@@ -169,6 +169,15 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     staticmethod,
                                     apply_wrapper=True)
+        # reduce_scatter_to_sequence_parallel_region
+        MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_scatter_to_sequence_parallel_region',
+                                    torch._dynamo.disable,
+                                    apply_wrapper=True)
+        # reduce_from_tensor_model_parallel_region
+        MegatronAdaptation.register('megatron.core.tensor_parallel.mappings.reduce_from_tensor_model_parallel_region',
+                                    torch._dynamo.disable,
+                                    apply_wrapper=True)
+
         # flux
         if int(os.getenv("USE_FLUX_OVERLAP", "0")):
             from ..core.tensor_parallel.layers import (
@@ -189,6 +198,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         from ..training.initialize import _initialize_distributed
         from ..training.initialize import _compile_dependencies
         from ..training.training import train
+        from ..training.initialize import _set_random_seed

         MegatronAdaptation.register('megatron.training.tokenizer.tokenizer.build_tokenizer',
                                     build_tokenizer)
@@ -199,6 +209,10 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.initialize._compile_dependencies',
                                     _compile_dependencies)

+        # add fixed random seed setup
+        MegatronAdaptation.register('megatron.training.initialize._set_random_seed',
+                                    _set_random_seed)
+
         # add trace_handler
         MegatronAdaptation.register('megatron.training.training.train',
                                     train)
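The first hunk above routes the two tensor-parallel mappings through `torch._dynamo.disable`. As a minimal standalone sketch of the effect (not the MegatronAdaptation machinery itself), wrapping a callable this way keeps it out of `torch.compile` graph capture and runs it eagerly; `tp_all_reduce_stub` below is a hypothetical stand-in for a collective such as `reduce_from_tensor_model_parallel_region`:

```python
import torch


def tp_all_reduce_stub(x):
    # Hypothetical stand-in for a tensor-parallel collective; illustration only.
    return x * 1.0


# torch._dynamo.disable(fn) returns a callable that TorchDynamo will not trace:
# inside a compiled region it triggers a graph break and is executed eagerly.
tp_all_reduce_eager = torch._dynamo.disable(tp_all_reduce_stub)


@torch.compile
def forward(x):
    return tp_all_reduce_eager(x) + 1


if __name__ == "__main__":
    print(forward(torch.ones(4)))  # the stub stays outside the captured graph
```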
......
@@ -7,6 +7,7 @@ from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.enums import ModelType
 from megatron.core.pipeline_parallel import p2p_communication
+from megatron.core.pipeline_parallel.schedules import set_current_microbatch
 from megatron.core.transformer.cuda_graphs import create_cudagraphs
 from megatron.core.utils import (
     get_attr_wrapped_model,
@@ -28,19 +29,6 @@ from megatron.core.pipeline_parallel.schedules import (
 from .combined_1f1b import VppContextManager, forward_backward_step, set_streams, wrap_forward_func


-def set_current_microbatch(model, microbatch_id):
-    """Set the current microbatch."""
-    decoder_exists = True
-    decoder = None
-    try:
-        decoder = get_attr_wrapped_model(model, "decoder")
-    except RuntimeError:
-        decoder_exists = False
-    if decoder_exists and decoder is not None:
-        for layer in decoder.layers:
-            layer.current_microbatch = microbatch_id
-
-
 def get_pp_rank_microbatches(
     num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage, forward_only=False
 ):
......
@@ -16,35 +16,6 @@ from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerL
 class TransformerLayer(MegatronCoreTransformerLayer):

-    def _callable_wrapper(
-        self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
-    ):
-        """
-        Wraps a function call so that it waits for a given CUDA event before
-        proceeding and then runs the function on a specified CUDA stream.
-        """
-        torch.cuda.nvtx.range_push(func.__name__)
-        event.wait(stream)
-        with torch.cuda.stream(stream):
-            outputs = func(*args, **kwargs)
-        event.record(stream)
-        if skip_detach:
-            torch.cuda.nvtx.range_pop()
-            return outputs
-        detached_output_tensors = []
-        if not is_forward:
-            torch.cuda.nvtx.range_pop()
-            return outputs, detached_output_tensors
-        for tensor in outputs:
-            if tensor is None:
-                detached_output_tensors.append(None)
-            elif tensor.dtype.is_floating_point:
-                detached_output_tensors.append(tensor.detach().requires_grad_(True))
-            else:
-                detached_output_tensors.append(tensor.detach())
-        torch.cuda.nvtx.range_pop()
-        return outputs, detached_output_tensors
-
     def forward(
         self,
         hidden_states: Tensor,
@@ -123,6 +94,12 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         residual = hidden_states

         # Optional Input Layer norm
-        input_layernorm_output = self.input_layernorm(hidden_states)
+        if self.recompute_input_layernorm:
+            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
+                self.input_layernorm, hidden_states
+            )
+        else:
+            input_layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
@@ -138,6 +115,13 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             sequence_len_offset=sequence_len_offset,
         )

+        if self.recompute_input_layernorm:
+            # discard the output of the input layernorm and register the recompute
+            # as a gradient hook of attention_output_with_bias[0]
+            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
+                attention_output_with_bias[0]
+            )
+
         # TODO: could we move `bias_dropout_add_exec_handler` itself
         # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():
@@ -178,6 +162,12 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         )

         # Optional Layer norm post the cross-attention.
-        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        if self.recompute_pre_mlp_layernorm:
+            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
+                self.pre_mlp_layernorm, hidden_states
+            )
+        else:
+            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

         probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
@@ -249,6 +239,16 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         if shared_expert_output is not None:
             output += shared_expert_output
         mlp_output_with_bias = (output, mlp_bias)
+
+        if self.recompute_pre_mlp_layernorm:
+            # discard the output of the pre-mlp layernorm and register the recompute
+            # as a gradient hook of mlp_output_with_bias[0]
+            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
+                mlp_output_with_bias[0]
+            )
+
+        # TODO: could we move `bias_dropout_add_exec_handler` itself
+        # inside the module provided in the `bias_dropout_add_spec` module?
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                 mlp_output_with_bias, residual, self.hidden_dropout
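The four hunks above wire `CheckpointWithoutOutput` into the layer so the input and pre-MLP layernorms are recomputed during backward instead of keeping their activations, with the recompute registered as a gradient hook on the attention/MLP output. As a rough standard-library analogue (a sketch only, with illustrative module names and sizes, not the patched helper), `torch.utils.checkpoint` shows the same recompute-in-backward trade-off, although it does not additionally free the layernorm output once its consumer has run:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Illustrative sizes; not taken from the patch.
norm = torch.nn.LayerNorm(512)
proj = torch.nn.Linear(512, 512)  # stands in for the attention / MLP that consumes the norm output

x = torch.randn(8, 512, requires_grad=True)

# The layernorm runs without storing activations for backward; it is
# re-executed during the backward pass to produce gradients.
y = proj(checkpoint(norm, x, use_reentrant=False))
y.sum().backward()
print(x.grad.shape)
```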
......
@@ -105,7 +105,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
         # Call the init process
         init_process_group_kwargs = {
-            'backend' : args.distributed_backend,
+            'backend': args.distributed_backend,
             'world_size': args.world_size,
             'rank': args.rank,
             'init_method': args.dist_url,
@@ -149,3 +149,35 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
             f"> initialized pipeline model parallel with size "
             f"{mpu.get_pipeline_model_parallel_world_size()}"
         )
+
+
+def _set_random_seed(
+    seed_: int,
+    data_parallel_random_init: bool = False,
+    te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
+    """Set random seed for reproducibility."""
+    args = get_args()
+    if seed_ is not None and seed_ > 0:
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.device_count() > 0:
+            tensor_parallel.model_parallel_cuda_manual_seed(
+                seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng
+            )
+        if args.reproduce:
+            assert (args.attention_dropout > 0) is False, f"To utilize the reproduction function, args.attention_dropout = {args.attention_dropout} must be set to 0."
+            assert (args.hidden_dropout > 0) is False, f"To utilize the reproduction function, args.hidden_dropout = {args.hidden_dropout} must be set to 0."
+            torch.backends.cudnn.deterministic = True   # make the cuDNN backend use deterministic algorithms
+            torch.backends.cudnn.benchmark = False      # pin the convolution algorithm selection
+            torch.use_deterministic_algorithms(True)    # use deterministic torch operators to avoid nondeterminism
+    else:
+        raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
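A practical note on the deterministic switches in `_set_random_seed` (an assumption about the runtime environment, not something this patch configures): once `torch.use_deterministic_algorithms(True)` is active, PyTorch also expects `CUBLAS_WORKSPACE_CONFIG` to be exported (for example `:4096:8`) before cuBLAS runs on GPU, otherwise deterministic matmuls raise a runtime error. A minimal single-process sketch of the same seeding and determinism setup (`seed_everything` is a hypothetical helper, not part of the patch):

```python
import os
import random

import numpy as np
import torch

# Assumption: required for deterministic cuBLAS GEMMs when running on CUDA.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")


def seed_everything(seed: int = 1234) -> None:
    """Single-process analogue of the rank-offset seeding done in _set_random_seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True   # force deterministic cuDNN algorithms
    torch.backends.cudnn.benchmark = False      # do not auto-tune convolution algorithms
    torch.use_deterministic_algorithms(True)    # error out on nondeterministic ops


seed_everything()
print(torch.randn(2, 2))  # identical across runs with the same seed
```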
@@ -9,8 +9,10 @@ from megatron.training.tokenizer.tokenizer import (
     _Llama2Tokenizer,
     CustomTikTokenizer,
     _NullTokenizer,
+    _NullMultimodalTokenizer,
     _vocab_size_with_padding
 )
+from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer


 def build_tokenizer(args, **kwargs):
@@ -92,7 +94,11 @@ def build_tokenizer(args, **kwargs):
             args.tokenizer_prompt_format,
             args.special_tokens,
             args.image_tag_type,
+            args.force_system_message,
         )
+    elif args.tokenizer_type == 'NullMultimodalTokenizer':
+        assert args.vocab_size is not None
+        tokenizer = _NullMultimodalTokenizer(args.vocab_size)
     elif args.tokenizer_type == "DeepSeekV2Tokenizer":
         tokenizer = _DeepSeekV2Tokenizer(args.tokenizer_model, args.extra_vocab_size)
         args.padded_vocab_size = tokenizer.vocab_size
......
@@ -53,18 +53,9 @@ from megatron.training.training import (
 stimer = StragglerDetector()


-def train(
-    forward_step_func,
-    model,
-    optimizer,
-    opt_param_scheduler,
-    train_data_iterator,
-    valid_data_iterator,
-    process_non_loss_data_func,
-    config,
-    checkpointing_context,
-    non_loss_data_func,
-):
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func, config, checkpointing_context, non_loss_data_func):
     """Training function: run train_step desired number of times, run validation, checkpoint."""
     args = get_args()
     timers = get_timers()
@@ -74,10 +65,7 @@ def train(
     try:
         from workload_inspector.utils.webserver import run_server
         import threading
-        threading.Thread(
-            target=run_server, daemon=True, args=(torch.distributed.get_rank(),)
-        ).start()
+        threading.Thread(target=run_server, daemon=True, args=(torch.distributed.get_rank(), )).start()
     except ModuleNotFoundError:
         print_rank_0("workload inspector module not found.")
@@ -100,17 +88,11 @@ def train(
     rerun_state_machine.current_iteration = iteration

     # Track E2E metrics at the start of training.
-    one_logger_utils.on_train_start(
-        iteration=iteration,
-        consumed_train_samples=args.consumed_train_samples,
-        train_samples=args.train_samples,
-        seq_length=args.seq_length,
-        train_iters=args.train_iters,
-        save=args.save,
-        async_save=args.async_save,
-        log_throughput=args.log_throughput,
-        num_floating_point_operations_so_far=args.num_floating_point_operations_so_far,
-    )
+    one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
+                                    train_samples=args.train_samples, seq_length=args.seq_length,
+                                    train_iters=args.train_iters, save=args.save, async_save=args.async_save,
+                                    log_throughput=args.log_throughput,
+                                    num_floating_point_operations_so_far=args.num_floating_point_operations_so_far)

     num_floating_point_operations_so_far = args.num_floating_point_operations_so_far
@@ -118,10 +100,9 @@ def train(
     config.grad_scale_func = optimizer.scale_loss
     config.timers = timers
     if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
-        assert config.no_sync_func is None, (
-            'When overlap_grad_reduce is True, config.no_sync_func must be None; '
-            'a custom no_sync_func is not supported when overlapping grad-reduce'
-        )
+        assert config.no_sync_func is None, \
+            ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
+             'a custom no_sync_func is not supported when overlapping grad-reduce')
         config.no_sync_func = [model_chunk.no_sync for model_chunk in model]
         if len(model) == 1:
             config.no_sync_func = config.no_sync_func[0]
@@ -145,9 +126,8 @@ def train(
     if args.manual_gc:
         # Disable the default garbage collector and perform the collection manually.
         # This is to align the timing of garbage collection across ranks.
-        assert (
-            args.manual_gc_interval >= 0
-        ), 'Manual garbage collection interval should be larger than or equal to 0'
+        assert args.manual_gc_interval >= 0, \
+            'Manual garbage collection interval should be larger than or equal to 0'
         gc.disable()
         gc.collect()
@@ -157,13 +137,10 @@ def train(
         world = torch.distributed.get_world_size()
         rank = torch.distributed.get_rank()
         mmcnt = args.straggler_minmax_count
-        stimer.configure(
-            world,
-            rank,
-            mmcnt=mmcnt,
-            enabled=not args.disable_straggler_on_startup,
-            port=args.straggler_ctrlr_port,
-        )
+        stimer.configure(world, rank,
+                mmcnt = mmcnt,
+                enabled = not args.disable_straggler_on_startup,
+                port = args.straggler_ctrlr_port)
     num_floating_point_operations_since_last_log_event = 0.0

     num_microbatches = get_num_microbatches()
@@ -171,10 +148,10 @@ def train(
     eval_iterations = 0

     def get_e2e_base_metrics():
-        """Get base metrics values for one-logger to calculate E2E tracking metrics."""
-        num_floating_point_operations_since_current_train_start = (
-            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
-        )
+        """Get base metrics values for one-logger to calculate E2E tracking metrics.
+        """
+        num_floating_point_operations_since_current_train_start = \
+            num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
         return {
             'iteration': iteration,
             'train_duration': timers('interval-time').active_time(),
@@ -184,7 +161,7 @@ def train(
             'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
             'consumed_train_samples': args.consumed_train_samples,
             'world_size': args.world_size,
-            'seq_length': args.seq_length,
+            'seq_length': args.seq_length
         }

     # Cache into one-logger for callback.
     if one_logger:
@@ -192,11 +169,7 @@ def train(
         one_logger.store_set('get_e2e_base_metrics', get_e2e_base_metrics)

     prof = None
-    if (
-        args.profile
-        and torch.distributed.get_rank() in args.profile_ranks
-        and args.use_pytorch_profiler
-    ):
+    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
         def trace_handler(p):
             from pathlib import Path
             Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
@@ -242,9 +215,8 @@ def train(
     pre_hook_enabled = False
     # Also, check weight hash across DP replicas to be very pedantic.
     if args.check_weight_hash_across_dp_replicas_interval is not None:
-        assert check_param_hashes_across_dp_replicas(
-            model, cross_check=True
-        ), "Parameter hashes not matching across DP replicas"
+        assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
+            "Parameter hashes not matching across DP replicas"
         torch.distributed.barrier()
         print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
@@ -270,20 +242,14 @@ def train(
         # to make sure training configuration is still valid.
         update_num_microbatches(args.consumed_train_samples, consistency_check=False, verbose=True)
         if get_num_microbatches() != num_microbatches and iteration != 0:
-            assert get_num_microbatches() > num_microbatches, (
-                f"Number of microbatches should be increasing due to batch size rampup; "
-                f"instead going from {num_microbatches} to {get_num_microbatches()}"
-            )
+            assert get_num_microbatches() > num_microbatches, \
+                (f"Number of microbatches should be increasing due to batch size rampup; "
+                 f"instead going from {num_microbatches} to {get_num_microbatches()}")
             if args.save is not None:
-                save_checkpoint_and_time(
-                    iteration,
-                    model,
-                    optimizer,
-                    opt_param_scheduler,
-                    num_floating_point_operations_so_far,
-                    checkpointing_context,
-                    train_data_iterator=train_data_iterator,
-                )
+                save_checkpoint_and_time(iteration, model, optimizer,
+                                         opt_param_scheduler,
+                                         num_floating_point_operations_so_far,
+                                         checkpointing_context, train_data_iterator=train_data_iterator)
         num_microbatches = get_num_microbatches()
         update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
@@ -292,9 +258,9 @@ def train(
             # Dummy train_step to fast forward train_data_iterator.
             dummy_train_step(train_data_iterator)
             iteration += 1
-            batch_size = (
-                mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-            )
+            batch_size = mpu.get_data_parallel_world_size() * \
+                         args.micro_batch_size * \
+                         get_num_microbatches()
             args.consumed_train_samples += batch_size
             args.skipped_train_samples += batch_size
             continue
@@ -302,28 +268,19 @@ def train(
         # Run training step.
         args.curr_iteration = iteration
         ft_integration.on_training_step_start()
-        (
-            loss_dict,
-            skipped_iter,
-            should_checkpoint,
-            should_exit,
-            exit_code,
-            grad_norm,
-            num_zeros_in_grad,
-        ) = train_step(
-            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config
-        )
-        ft_integration.on_training_step_end()
-        if should_checkpoint:
-            save_checkpoint_and_time(
-                iteration,
-                model,
-                optimizer,
-                opt_param_scheduler,
-                num_floating_point_operations_so_far,
-                checkpointing_context,
-                train_data_iterator=train_data_iterator,
-            )
+        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
+            train_step(forward_step_func,
+                       train_data_iterator,
+                       model,
+                       optimizer,
+                       opt_param_scheduler,
+                       config)
+        ft_integration.on_training_step_end()
+        if should_checkpoint:
+            save_checkpoint_and_time(iteration, model, optimizer,
+                                     opt_param_scheduler,
+                                     num_floating_point_operations_so_far,
+                                     checkpointing_context, train_data_iterator=train_data_iterator)
         if should_exit:
             break
@@ -346,13 +303,12 @@ def train(
             pre_hook_enabled = True

         iteration += 1
-        batch_size = (
-            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
-        )
+        batch_size = mpu.get_data_parallel_world_size() * \
+                     args.micro_batch_size * \
+                     get_num_microbatches()
         args.consumed_train_samples += batch_size
-        num_skipped_samples_in_batch = (
-            get_current_global_batch_size() - get_current_running_global_batch_size()
-        )
+        num_skipped_samples_in_batch = (get_current_global_batch_size() -
+                                        get_current_running_global_batch_size())
         if args.decrease_batch_size_if_needed:
             assert num_skipped_samples_in_batch >= 0
         else:
@@ -378,22 +334,16 @@ def train(
                 decoupled_learning_rate = param_group['lr']
             else:
                 learning_rate = param_group['lr']
-        report_memory_flag = training_log(
-            loss_dict,
-            total_loss_dict,
-            learning_rate,
-            decoupled_learning_rate,
-            iteration,
-            loss_scale,
-            report_memory_flag,
-            skipped_iter,
-            grad_norm,
-            params_norm,
-            num_zeros_in_grad,
-        )
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          learning_rate,
+                                          decoupled_learning_rate,
+                                          iteration, loss_scale,
+                                          report_memory_flag, skipped_iter,
+                                          grad_norm, params_norm, num_zeros_in_grad)

         # Evaluation.
-        if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
+        if args.eval_interval and iteration % args.eval_interval == 0 and \
+           args.do_valid:
             timers('interval-time').stop()
             if should_disable_forward_pre_hook(args):
                 disable_forward_pre_hook(model)
@@ -403,18 +353,11 @@ def train(
                 gc.collect()
             prefix = f'iteration {iteration}'
             timers('eval-time', log_level=0).start(barrier=True)
-            evaluate_and_print_results(
-                prefix,
-                forward_step_func,
-                valid_data_iterator,
-                model,
-                iteration,
-                process_non_loss_data_func,
-                config,
-                verbose=False,
-                write_to_tensorboard=True,
-                non_loss_data_func=non_loss_data_func,
-            )
+            evaluate_and_print_results(prefix, forward_step_func,
+                                       valid_data_iterator, model,
+                                       iteration, process_non_loss_data_func,
+                                       config, verbose=False, write_to_tensorboard=True,
+                                       non_loss_data_func=non_loss_data_func)
             eval_duration += timers('eval-time').elapsed()
             eval_iterations += args.eval_iters
             timers('eval-time').stop()
@@ -430,25 +373,13 @@ def train(
         # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
         # Some of these only happen at specific iterations.
-        post_training_step_callbacks(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            prof,
-            num_floating_point_operations_since_last_log_event,
-        )
+        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
+                                     num_floating_point_operations_since_last_log_event)

         # Checkpoint and decide whether to exit.
-        should_exit = checkpoint_and_decide_exit(
-            model,
-            optimizer,
-            opt_param_scheduler,
-            iteration,
-            num_floating_point_operations_so_far,
-            checkpointing_context,
-            train_data_iterator,
-        )
+        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
+                                                 num_floating_point_operations_so_far,
+                                                 checkpointing_context, train_data_iterator)
         if should_exit:
             break
@@ -477,7 +408,6 @@ def train(
         if wandb_writer:
             wandb_writer.finish()
         ft_integration.shutdown()
-        one_logger_utils.finish()
         sys.exit(exit_code)

     return iteration, num_floating_point_operations_so_far