"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "1359d6cf3bbc925432d2b2f060e17a39983d7ba9"
Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from typing import Callable

import torch

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

try:
    import modelopt.torch.quantization as mtq
    from modelopt.torch.quantization.nn import QuantModuleRegistry
    from modelopt.torch.quantization.nn.modules.quant_linear import _QuantLinear

    has_nvidia_modelopt = True
except Exception:
    has_nvidia_modelopt = False


class Linear(torch.nn.Linear):
    """Local Linear impl as a replacement of TELinear."""

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        parallel_mode: str,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        skip_weight_param_allocation: bool,
        tp_comm_buffer_name: str = None,
        is_expert: bool = False,
    ):
        self.config = config
        self._return_bias = skip_bias_add and bias

        if skip_weight_param_allocation:
            raise ValueError('torch.nn.Linear layers do not support skip_weight_param_allocation')

        super().__init__(
            in_features=input_size, out_features=output_size, bias=bias, dtype=config.params_dtype
        )

        for param in self.parameters():
            if is_expert:
                # Reduce the gradient on the expert_data_parallel group for expert linear layers
                setattr(param, 'allreduce', self.config.expert_model_parallel_size == 1)
            else:
                # Reduce the gradient on DP group
                setattr(param, 'allreduce', True)
            setattr(param, 'sequence_parallel', self.config.sequence_parallel)

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Sharding along axis 0, bias sharded"""
        state_dict = self.state_dict(prefix='', keep_vars=True)
        for k, v in state_dict.items():
            if "_amax" in k or "_scale" in k:
                if v.ndim == 0:
                    state_dict[k] = v.view(1)
        sharded_state_dict = make_sharded_tensors_for_checkpoint(
            state_dict, prefix, sharded_offsets=sharded_offsets
        )
        return sharded_state_dict

    def forward(self, x):
        """Forward."""
        out = super().forward(x)
        if self._return_bias:
            return out
        return out, None


if has_nvidia_modelopt:
    QuantModuleRegistry.register({Linear: Linear.__class__.__name__})(_QuantLinear)


class RealQuantTransformerLayer(TransformerLayer):
    """Real quantization transformer layer base class.

    This base class initializes the default TransformerLayer and immediately
    performs weight-only real quantization via TensorRT Model Optimizer.
    All linear weights (Linear, ColumnParallelLinear, RowParallelLinear) picked
    up will be replaced with a low-bit data type (default torch.uint8). If a sub-byte
    real_quant_cfg is used, the weight shape will further be halved.

    This module cannot be trained (all parameters frozen).
    """

    verbose: bool = False
    real_quant_cfg: str = "None"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if has_nvidia_modelopt and self.real_quant_cfg != "None":
            REAL_QUANT_CFG_CHOICES = {
                "fp8_real_quant": mtq.FP8_PER_TENSOR_REAL_QUANT_CFG,
                "fp8_blockwise_real_quant": mtq.FP8_BLOCKWISE_REAL_QUANT_CFG,
            }
            mtq_cfg = REAL_QUANT_CFG_CHOICES.get(self.real_quant_cfg, None)
            if mtq_cfg is None:
                raise ValueError(
                    "RealQuantTransformerLayer does not support {}".format(self.real_quant_cfg)
                )
            self._collect_original_tensor_info()
            mtq.quantize(self, mtq_cfg)
            delattr(self, "_modelopt_state")

            # Freeze all parameters since the real-quant linears cannot be trained.
            for param in self.parameters():
                param.requires_grad = False

            if self.verbose:
                self._report_quantize_tensor_info()

    def _collect_original_tensor_info(self):
        self._original_tensor_info = {}
        for k, v in self.state_dict().items():
            if isinstance(v, torch.Tensor):
                self._original_tensor_info[k] = (str(v.dtype), str(v.shape))

    def _report_quantize_tensor_info(self):
        torch.distributed.barrier()
        if torch.distributed.get_rank() == 0:
            for k, v in self.state_dict().items():
                if not isinstance(v, torch.Tensor):
                    continue
                original_dtype, original_shape = self._original_tensor_info.get(k, ("-", "-"))
                print(
                    "{:<64} {:<16} {:<32} {:<16} {:<32}".format(
                        k, original_dtype, original_shape, str(v.dtype), str(v.shape)
                    )
                )
        torch.distributed.barrier()


class FP8WeightTransformerLayer(RealQuantTransformerLayer):
    """FP8 weight transformer layer."""

    real_quant_cfg: str = "fp8_real_quant"


class BlockwiseFP8WeightTransformerLayer(RealQuantTransformerLayer):
    """Blockwise FP8 weight transformer layer."""

    real_quant_cfg: str = "fp8_blockwise_real_quant"
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine.
    """
    mamba_state_dict_keys_map = {}
    transformer_state_dict_keys_map = {}
    if remap_te_layernorm:
        mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
        transformer_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }

    mamba_layer = ModuleSpec(
        module=MambaLayer,
        submodules=MambaLayerSubmodules(
            norm=TENorm,
            mixer=ModuleSpec(
                module=MambaMixer,
                submodules=MambaMixerSubmodules(
                    in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
                ),
            ),
            mamba_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=mamba_state_dict_keys_map,
        ),
    )

    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    attention_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    mlp_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    return ModuleSpec(
        module=MambaStack,
        submodules=MambaStackSubmodules(
            mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
        ),
    )
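A short, hypothetical usage sketch of the spec factory above (Transformer Engine and the Mamba dependencies imported at the top of this file are assumed to be installed):

spec = get_mamba_stack_modelopt_spec(local_core_attention=False, remap_te_layernorm=True)

print(spec.module)  # the MambaStack class
print(spec.submodules.attention_layer.submodules.sharded_state_dict_keys_map)
# {'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
#  'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'}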
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

+import datetime
import inspect
import logging
import math
import os
import random
+import re
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Union

@@ -12,6 +14,9 @@ from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
import torch

+import megatron.core.parallel_state as mpu
+from megatron.core.dist_checkpointing.mapping import ShardedObject

"""DISCLAIMER: THIS IS AN EXPERIMENTAL FEATURE.

The rerun state machine implementation in this file is alpha-level code to help

@@ -34,6 +39,7 @@ EXIT_CODE_RESUME_TO_DISAMBIGUATE: int = 16
EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int = 17

SerializableStateType = Union[list, dict]
+DataIteratorArgType = Optional[Union["RerunDataIterator", list["RerunDataIterator"]]]


class Caller(NamedTuple):

@@ -105,6 +111,17 @@ class RerunState(Enum):
    RERUNNING_AGAIN_FROM_CHECKPOINT = 5


+class RerunValidationStatus(str, Enum):
+    """Enum representing the status of a record in the tracker log file"""
+
+    RERUN_DISABLED = 'rerun_disabled'
+    INITIAL_RUN = 'initial_run'
+    FIRST_RERUN_NOT_REPRODUCIBLE = 'first_rerun_not_reproducible'
+    FIRST_RERUN_REPRODUCIBLE = "first_rerun_reproducible"
+    SECOND_RERUN_NOT_REPRODUCIBLE = "second_rerun_not_reproducible"
+    SECOND_RERUN_REPRODUCIBLE = "second_rerun_reproducible"
+
+
COMPARISON_MATCH: float = 0.0
COMPARISON_MISMATCH: float = math.inf
@@ -121,6 +138,7 @@ class RerunStateMachine:
        state_restore_func: optional function to restore the state saved by state_save_func.
        mode: operating mode for the rerun state machine, default is disabled.
        error_injector: optional result injection engine, default is no result injection.
+        result_rejected_tracker_filename: optional name of file tracking `result rejected` events.

    Example usage:

@@ -170,6 +188,7 @@ class RerunStateMachine:
        state_restore_func: Optional[Callable[[SerializableStateType], None]] = None,
        mode: RerunMode = RerunMode.DISABLED,
        error_injector: Optional["RerunErrorInjector"] = None,
+        result_rejected_tracker_filename: Optional[str] = None,
    ) -> None:
        self.mode: RerunMode = mode
        self.state: RerunState = RerunState.NOT_RUNNING_YET

@@ -192,6 +211,18 @@ class RerunStateMachine:
        self.suspicious_node: str = None
        self.suspicious_device: int = None

+        # Keep track of `result_rejected` events.
+        # Make sure the file can be written to and abort if not.
+        self.result_rejected_tracker_filename = result_rejected_tracker_filename
+        if self.result_rejected_tracker_filename is not None:
+            try:
+                with open(self.result_rejected_tracker_filename, 'a'):
+                    pass
+            except Exception as e:
+                raise RuntimeError(
+                    f"RerunStateMachine result validation log cannot be appended to! ({e})"
+                )
+
        self.saved_state: Optional[SerializableStateType] = None
        self.state_save_func: Optional[Callable[[], SerializableStateType]] = state_save_func
        self.state_restore_func: Optional[Callable[[SerializableStateType], None]] = (

@@ -199,16 +230,19 @@ class RerunStateMachine:
        )
        self.data_iterator_checkpoints: Optional[list[SerializableStateType]] = None
-        self.last_loss: Optional[float] = None
+        self.large_value_counts: dict[str, int] = {}
+        self.max_values: dict[str, float] = {}
        self.saved_results: dict[Call, Any] = {}
        self.stats: dict[Caller, QuickStats] = defaultdict(lambda: QuickStats())

-        logger.warning(f"RerunStateMachine initialized in mode {mode}")
+        if _safe_get_rank() == 0:
+            logger.warning(f"RerunStateMachine initialized in mode {mode}")

    def set_mode(self, mode: RerunMode) -> None:
        """Method to set the operating mode"""
-        logger.warning(f"Setting RerunStateMachine mode {mode}")
+        if _safe_get_rank() == 0:
+            logger.warning(f"Setting RerunStateMachine mode {mode}")
        self.mode = mode

    def get_mode(self) -> RerunMode:
@@ -216,9 +250,7 @@ class RerunStateMachine:
        return self.mode

-    def should_run_forward_backward(
-        self, data_iterator: Optional[Union["RerunDataIterator", list]]
-    ) -> bool:
+    def should_run_forward_backward(self, data_iterator: DataIteratorArgType) -> bool:
        """Method instructing whether to (re)run the forward-backward pass.

        Args:

@@ -243,30 +275,20 @@ class RerunStateMachine:
        self.validation_counts = defaultdict(int)

-        data_iterators: list[RerunDataIterator] = []
-        if self.mode != RerunMode.DISABLED and data_iterator is not None:
-            if not isinstance(data_iterator, list):
-                data_iterators = [data_iterator]
-            else:
-                data_iterators = data_iterator
-            for d in data_iterators:
-                assert (
-                    isinstance(d, RerunDataIterator),
-                    "data iterator is not wrapped with RerunDataIterator",
-                )
+        data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

        # Are we about to start the initial run?
        if self.state == RerunState.NOT_RUNNING_YET:
            if self.mode == RerunMode.DISABLED:
                self.state = RerunState.INITIAL_RUN
+                self.current_iteration += 1  # Increment self.current_iteration for reporting.
                return True
            if self.data_iterator_checkpoints is not None:
-                assert (
-                    len(self.data_iterator_checkpoints) == len(data_iterators),
-                    "data_iterator has different length than checkpointed data iterator",
-                )
+                assert len(self.data_iterator_checkpoints) == len(
+                    data_iterators
+                ), "data iterator has different length than checkpointed data iterator"
                for i, d in enumerate(data_iterators):
-                    d.set_checkpoint_state(self.data_iterator_checkpoints[i])
+                    d.load_state_dict(self.data_iterator_checkpoints[i])
                self.data_iterator_checkpoints = None
            self._save_state()
            if data_iterators:
@@ -464,10 +486,28 @@ class RerunStateMachine:
            verifying the result is the same.
        """

-        # Skip the validation check if the state machine is disabled or if we haven't run
-        # a full iteration yet. We cannot guarantee that a checkpoint can be taken before the
-        # optimizer has been stepped at least once.
-        if self.mode == RerunMode.DISABLED or self.current_iteration < 1:
+        # If reruns are disabled, still validate the result and throw a RuntimeError if it is
+        # rejected. This is a backward-compatible behavior.
+        if self.mode == RerunMode.DISABLED:
+            result_rejected: bool = rejection_func(result)
+            if result_rejected:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.RERUN_DISABLED, result=result, message=message
+                )
+                rank: int = _safe_get_rank()
+                node: str = os.uname()[1]
+                device: int = torch.cuda.current_device()
+                full_message: str = (
+                    f"Rank {rank}, node {node}, device {device}, "
+                    f"iteration {self.current_iteration}: "
+                    f"Unexpected result {result} (message='{message}')"
+                )
+                raise RuntimeError(full_message)
+            return
+
+        # Skip the validation on the first iteration, as we cannot guarantee a checkpoint can be
+        # taken before the optimizer has been stepped at least once.
+        if self.current_iteration < 1:
            return

        if comparison_func is None:
@@ -523,6 +563,9 @@ class RerunStateMachine:
            self.failed_validation_call = validation_call
            self.initial_result = result
            self.rerun_requested = True
+            self._log_validation_error_to_file(
+                status=RerunValidationStatus.INITIAL_RUN, result=result, message=message
+            )
            logger.error(
                f"Unexpected result {result} at {validation_call.caller.filename} "
                f"line {validation_call.caller.lineno}, "

@@ -547,6 +590,11 @@ class RerunStateMachine:
                    "First rerun: unexpected result is not reproducible within the tolerance "
                    f"({result} != {self.initial_result})"
                )
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.FIRST_RERUN_NOT_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                log_failure("Possible transient error!")
            else:
                self.checkpoint_requested = True

@@ -554,6 +602,11 @@ class RerunStateMachine:
                # rerunning on the same GPU when we resume from the checkpoint.
                self.suspicious_node = os.uname()[1]
                self.suspicious_device = torch.cuda.current_device()
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.FIRST_RERUN_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "First rerun: unexpected result is reproducible within the tolerance "
                    f"({result} = {self.initial_result}). "

@@ -571,12 +624,22 @@ class RerunStateMachine:
                )
                self.restart_again_requested = True
            elif comparison > tolerance:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.SECOND_RERUN_NOT_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "Second rerun: unexpected result is not reproducible on a different GPU, "
                    f"therefore was likely incorrect ({result} != {self.initial_result})"
                )
                log_failure("Possible persistent error!")
            else:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "Second rerun: unexpected result is reproducible on a different GPU, "
                    f"therefore it was likely correct ({result} = {self.initial_result})"
@@ -587,13 +650,31 @@ class RerunStateMachine:
        else:
            raise RuntimeError("Should not be here")

-    def is_spiky_loss(self, loss_tensor: torch.Tensor, threshold: float) -> bool:
-        """Helper method to estimate whether a loss is spiky.
+    def is_unexpectedly_large(
+        self,
+        result: torch.Tensor,
+        threshold: float,
+        context: str,
+        num_samples: int = 100,
+        resample: bool = False,
+    ) -> bool:
+        """Helper method to estimate whether a result is unexpectedly large.
+
+        Some calculation errors manifest themselves as results with unexpectedly large
+        exponents, e.g. spiky loss or grads. This method keeps track of a value over time
+        and flags it if it exceeds a certain threshold expressed as a multiple factor of
+        the max value observed.

        Args:
            loss_tensor: a zero-dim tensor containing the current loss.
-            threshold: a float representing the minimum relative variation
-                characterizing a spiky loss (e.g. 0.1 means +/- 10%).
+            threshold: a float representing the minimum trigger threshold,
+                e.g. 10 means > 10x max absolute value observed.
+            context: a string identifying the value. This is used to differentiate
+                between different invocations of validate_results targeting different
+                values, e.g. loss and grads.
+            num_samples: the sample size used to estimate the max value.
+                Default is 100 value samples.
+            resample: whether to resample the max value. Default is False.

        Returns:
            A boolean telling whether the current loss deviates from the previous
            loss by a factor greater than the threshold

@@ -612,37 +693,42 @@ class RerunStateMachine:
                loss = loss_fn(outputs)
                rerun_machine.validate_result(
                    result=loss,
-                    rejection_func=partial(rerun_machine.is_spiky_loss, threshold=0.1),
+                    rejection_func=partial(
+                        rerun_machine.is_unexpectedly_large,
+                        threshold=10,
+                        context="loss",
+                    ),
                    message="Spiky loss",
                    tolerance=0.0,
                    fatal=False,
                )
        """
-        loss: float = loss_tensor.item()
-        result: bool = False
-        if self.last_loss is not None:
-            # Ignore NaNs, and consider infinite loss as spiky.
-            if math.isnan(loss) or math.isnan(self.last_loss):
-                result = False
-            elif math.isinf(loss) or math.isinf(self.last_loss):
-                result = True
-            else:
-                result = math.fabs(loss - self.last_loss) / self.last_loss >= threshold
-        self.last_loss = loss
-        return result
+        value: float = math.fabs(result.item())
+        # Ignore NaNs and Infs. They should be checked separately.
+        if math.isnan(value) or math.isinf(value):
+            return False
+
+        if resample or context not in self.large_value_counts:
+            self.large_value_counts[context] = 0
+        if self.large_value_counts[context] < num_samples:
+            self.large_value_counts[context] += 1
+            self.max_values[context] = max(self.max_values.get(context, 0.0), value)
+            if self.large_value_counts[context] == num_samples:
+                logger.warning(f"Max value for {context}: {self.max_values[context]}")
+            return False
+        return value >= self.max_values[context] * threshold
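To make the warm-up behaviour above concrete, a small standalone sketch of the same heuristic (illustration only, not the Megatron API): the max absolute value is estimated over the first num_samples observations, and only later values are flagged.

def flag_large(value, state, threshold=10, num_samples=100):
    # state = {"count": 0, "max": 0.0}; mirrors the running-max logic of is_unexpectedly_large
    value = abs(value)
    if state["count"] < num_samples:
        state["count"] += 1
        state["max"] = max(state["max"], value)
        return False
    return value >= state["max"] * threshold

state = {"count": 0, "max": 0.0}
flags = [flag_large(v, state) for v in [1.0] * 100 + [5.0, 50.0]]
print(flags[-2:])  # [False, True]: 5.0 is under 10x the observed max, 50.0 is not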
-    def get_checkpoint_state(
-        self, data_iterator: Optional[Union["RerunDataIterator", list]]
-    ) -> list[dict[str, Any]]:
+    def state_dict(self, data_iterator: DataIteratorArgType, use_dist_ckpt: bool) -> dict[str, Any]:
        """Method that returns a state dict to be checkpointed.

        Args:
            data_iterator: the data iterator that needs to be checkpointed (or None
                if this checkpoint is not requested by the rerun state machine).
+            use_dist_ckpt: generate a distributed checkpoint.

        Returns:
-            A list of state dicts, each state dict representing the rerun state machine
-            for one rank.
+            A state dict representing the rerun state machine.

        Example usage:

@@ -651,65 +737,63 @@ class RerunStateMachine:
                ...
                rerun_state_machine = get_rerun_state_machine()
                checkpoint['rerun_state_machine'] = (
-                    rerun_state_machine.get_checkpoint_state(data_iterator)
+                    rerun_state_machine.state_dict(data_iterator, False)
                )
                ...
                return checkpoint
        """
-        data_iterators: list[RerunDataIterator]
-        if self.mode == RerunMode.DISABLED:
-            data_iterators = []
-        elif isinstance(data_iterator, (list, tuple)):
-            data_iterators = data_iterator
-        else:
-            data_iterators = [data_iterator] if data_iterator is not None else []
-        for d in data_iterators:
-            assert (
-                isinstance(d, RerunDataIterator),
-                "data iterator is not wrapped with RerunDataIterator",
-            )
+        data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

-        state: dict[str, Any] = {
+        # The RerunStateMachine state is different across all ranks. Therefore it needs to be
+        # checkpointed using a ShardedObject. However, we keep the common state in the non-sharded
+        # (common) checkpoint. This allows us to verify whether a checkpoint contains a
+        # RerunStateMachine state by checking the common checkpoint.
+        state_dict: dict[str, Any] = {
            'mode': self.mode,
-            'state': self.state,
-            'current_iteration': self.current_iteration,
-            'rerun_requested': self.rerun_requested,
-            'checkpoint_requested': self.checkpoint_requested,
-            'restart_again_requested': self.restart_again_requested,
-            'continue_requested': self.continue_requested,
-            # logged_sdc_enabled should not be saved (set at the job startup time).
-            'error_injector_checkpoint': self.error_injector.get_checkpoint_state(),
-            # validation_counts should not be saved (reset at the beginning of the training loop).
-            'failed_validation_call': self.failed_validation_call,
-            'initial_result': self.initial_result,
-            'suspicious_node': self.suspicious_node,
-            'suspicious_device': self.suspicious_device,
-            # No need to save saved_state (RNG state already captured in checkpoint).
-            'data_iterator_checkpoints': (
-                [d.get_checkpoint_state() for d in data_iterators] if data_iterators else None
-            ),
-            'last_loss': self.last_loss,
-            # No need to save saved_results and stats (resets when job resumes).
+            'sharded': {
+                'state': self.state,
+                'current_iteration': self.current_iteration,
+                'rerun_requested': self.rerun_requested,
+                'checkpoint_requested': self.checkpoint_requested,
+                'restart_again_requested': self.restart_again_requested,
+                'continue_requested': self.continue_requested,
+                # logged_sdc_enabled should not be saved (set at the job startup time).
+                'error_injector_checkpoint': self.error_injector.state_dict(),
+                # validation_counts should not be saved (reset at start of training loop).
+                'failed_validation_call': self.failed_validation_call,
+                'initial_result': self.initial_result,
+                'suspicious_node': self.suspicious_node,
+                'suspicious_device': self.suspicious_device,
+                # No need to save saved_state (RNG state already captured in checkpoint).
+                'data_iterator_checkpoints': (
+                    [d.state_dict() for d in data_iterators] if data_iterators else None
+                ),
+                'large_value_counts': self.large_value_counts,
+                'max_values': self.max_values,
+                # No need to save saved_results and stats (resets when job resumes).
+            },
        }
-        state_list: list[dict[str, Any]]
-        if (
-            torch.distributed.is_initialized()
-            and torch.distributed.get_world_size() > 1
-            and self.mode != RerunMode.DISABLED
-        ):
-            state_list = [None for i in range(torch.distributed.get_world_size())]
-            torch.distributed.all_gather_object(state_list, state)
-        else:
-            state_list = [state]
-        return state_list
+        if use_dist_ckpt:
+            pp_rank = mpu.get_pipeline_model_parallel_rank()
+            pp_size = mpu.get_pipeline_model_parallel_world_size()
+            tp_rank = mpu.get_tensor_model_parallel_rank()
+            tp_size = mpu.get_tensor_model_parallel_world_size()
+            state_dict['sharded'] = ShardedObject(
+                'rerun_state_machine_state',
+                state_dict['sharded'],
+                (pp_size, tp_size),
+                (pp_rank, tp_rank),
+                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
+            )
+        return state_dict
-    def set_checkpoint_state(self, state_list: list[dict[str, Any]]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Method that restores the state from a checkpoint.

        Args:
-            state_list: the list of state dicts saved in the checkpoint and originally
-                obtained from get_checkpoint_state().
+            state_dict: the state dict saved in the checkpoint and originally
+                obtained from state_dict().

        Returns:
            None

@@ -719,31 +803,59 @@ class RerunStateMachine:
                ...
                if 'rerun_state_machine' in checkpoint:
                    rerun_state_machine = get_rerun_state_machine()
-                    rerun_state_machine.set_checkpoint_state(checkpoint['rerun_state_machine'])
+                    rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
        """
        if self.mode == RerunMode.DISABLED:
+            if _safe_get_rank() == 0:
+                logger.warning(
+                    "RerunStateMachine disabled via CLI, ignoring machine state saved in checkpoint"
+                )
            return
-        rank: int = _safe_get_rank()
-        if rank == 0:
+        if state_dict['mode'] == RerunMode.DISABLED:
+            if _safe_get_rank() == 0:
+                logger.warning(
+                    "RerunStateMachine disabled in checkpoint but enabled via CLI, "
+                    "ignoring machine state saved in checkpoint"
+                )
+            return
+        if _safe_get_rank() == 0:
            logger.warning(
-                "Getting RerunStaeMachine state from checkpoint, args rerun options ignored"
+                "Getting RerunStateMachine state from checkpoint, CLI rerun args ignored"
            )
-        state = state_list[rank]
-        self.mode = state['mode']
-        self.state = state['state']
-        self.current_iteration = state['current_iteration']
-        self.rerun_requested = state['rerun_requested']
-        self.checkpoint_requested = state['checkpoint_requested']
-        self.restart_again_requested = state['restart_again_requested']
-        self.continue_requested = state['continue_requested']
-        self.error_injector.set_checkpoint_state(state['error_injector_checkpoint'])
-        self.failed_validation_call = state['failed_validation_call']
-        self.initial_result = state['initial_result']
-        self.suspicious_node = state['suspicious_node']
-        self.suspicious_device = state['suspicious_device']
-        self.data_iterator_checkpoints = state['data_iterator_checkpoints']
-        self.last_loss = state['last_loss']
+        self.mode = state_dict['mode']
+        sharded_dict = state_dict['sharded']
+        self.state = sharded_dict['state']
+        self.current_iteration = sharded_dict['current_iteration']
+        self.rerun_requested = sharded_dict['rerun_requested']
+        self.checkpoint_requested = sharded_dict['checkpoint_requested']
+        self.restart_again_requested = sharded_dict['restart_again_requested']
+        self.continue_requested = sharded_dict['continue_requested']
+        self.error_injector.load_state_dict(sharded_dict['error_injector_checkpoint'])
+        self.failed_validation_call = sharded_dict['failed_validation_call']
+        self.initial_result = sharded_dict['initial_result']
+        self.suspicious_node = sharded_dict['suspicious_node']
+        self.suspicious_device = sharded_dict['suspicious_device']
+        self.data_iterator_checkpoints = sharded_dict['data_iterator_checkpoints']
+        self.large_value_counts = sharded_dict['large_value_counts']
+        self.max_values = sharded_dict['max_values']
+    def _sanitize_data_iterators(
+        self, data_iterator: DataIteratorArgType
+    ) -> list["RerunDataIterator"]:
+        data_iterators: list[RerunDataIterator]
+        if self.mode == RerunMode.DISABLED:
+            data_iterators = []
+        elif not isinstance(data_iterator, list):
+            data_iterators = [data_iterator]
+        else:
+            data_iterators = data_iterator
+        data_iterators = [d for d in data_iterators if d is not None]
+        for d in data_iterators:
+            assert isinstance(
+                d, RerunDataIterator
+            ), "data iterator is not wrapped with RerunDataIterator"
+        return data_iterators
+
    def _get_validation_call_info(self) -> Call:
        """Internal method to get the context about the caller to validate_result()."""

@@ -817,6 +929,64 @@ class RerunStateMachine:
            logger.info(f"    From {caller.filename}, line {caller.lineno}:")
            logger.info(f"        {stats.print_stats()}")
+    def _log_validation_error_to_file(
+        self, status: RerunValidationStatus, result: Any, message: str
+    ) -> None:
+        if self.result_rejected_tracker_filename is not None:
+            # Append to log.
+            try:
+                rank: int = _safe_get_rank()
+                node: str = os.uname()[1]
+                device: int = torch.cuda.current_device()
+                with open(self.result_rejected_tracker_filename, 'a') as f:
+                    print(
+                        f"ts={datetime.datetime.now()} node={node} device={device} "
+                        f"jobID={os.getenv('SLURM_JOBID', 'N/A')} rank={rank} "
+                        f"iteration={self.current_iteration} status={status} result={result} "
+                        f"message='{message}'",
+                        file=f,
+                    )
+            except Exception as e:
+                logger.error(f"Could not log validation error! ({e})")
+
+    @classmethod
+    def get_skipped_iterations_from_tracker_file(cls, tracker_file_name: str) -> list[int]:
+        """Get list of iterations to skip from results recorded in tracker file. If an
+        "abnormality" (e.g., NaN or infinity in gradient) is seen more than once on a
+        given rank and iteration, the corresponding iteration is skipped.
+
+        Args:
+            tracker_file_name (str): Name of tracker file.
+
+        Returns:
+            list[int]: List of iterations to skip.
+        """
+        iterations_to_skip: set[int] = set()
+        seen: set[Tuple[int, int]] = set()
+        regex = r"ts=.+ node=.+ device=.+ jobID=.+ rank=(.+) iteration=(.+) status=(.+) .+"
+        try:
+            with open(tracker_file_name, 'r') as f:
+                for line in f.readlines():
+                    match = re.search(regex, line)
+                    if match:
+                        rank = int(match[1])
+                        iteration = int(match[2])
+                        status = match[3]
+                        # Skip an iteration if:
+                        # - Reruns were disabled and it has failed on the same rank twice,
+                        #   or
+                        # - Reruns were enabled and it was reproducible on the 2nd rerun.
+                        if status == RerunValidationStatus.RERUN_DISABLED:
+                            if (rank, iteration) in seen:
+                                iterations_to_skip.add(iteration)
+                            else:
+                                seen.add((rank, iteration))
+                        elif status == RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE:
+                            iterations_to_skip.add(iteration)
+        except Exception as e:
+            logger.error(f"Could not parse iterations to skip in tracker file! ({e})")
+        return sorted(iterations_to_skip)
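For reference, a record written by _log_validation_error_to_file() (and consumed by the classmethod above) looks roughly like the following; all field values here are made up for illustration.

sample_record = (
    "ts=2025-01-01 12:00:00.000000 node=node-007 device=3 jobID=123456 rank=17 "
    "iteration=5821 status=second_rerun_reproducible result=nan message='Spiky loss'"
)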
class RerunDataIterator:
    """A wrapper class for data iterators that adds replay capability.

@@ -837,8 +1007,8 @@ class RerunDataIterator:
        replay_data_iterator = RerunDataIterator(data_iterator)
    """

-    def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
-        self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+    def __init__(self, iterable: Iterable[Any]) -> None:
+        self.iterable: Iterable[Any] = iterable
        self.saved_microbatches: list[Any] = []
        self.replaying: bool = False
        self.replay_pos: int = 0

@@ -870,7 +1040,7 @@ class RerunDataIterator:
        self.replaying = False
        self.saved_microbatches = []

-    def get_checkpoint_state(self) -> SerializableStateType:
+    def state_dict(self) -> SerializableStateType:
        """Method to capture the state of the iterator as a serializable dict."""
        return {

@@ -879,7 +1049,7 @@ class RerunDataIterator:
            'replay_pos': self.replay_pos,
        }

-    def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
+    def load_state_dict(self, state_dict: SerializableStateType) -> None:
        """Method to restore the state saved as a serializable dict."""
        self.saved_microbatches = state_dict['saved_microbatches']

@@ -1051,7 +1221,7 @@ class RerunErrorInjector:
        else:
            raise RuntimeError("Should not be here")

-    def get_checkpoint_state(self) -> SerializableStateType:
+    def state_dict(self) -> SerializableStateType:
        """Method to capture the state of the error injector as a serializable dict."""
        return {

@@ -1061,7 +1231,7 @@ class RerunErrorInjector:
            'injected_error_type': self.injected_error_type,
        }

-    def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
+    def load_state_dict(self, state_dict: SerializableStateType) -> None:
        """Method to restore the state saved as a serializable dict."""
        self.error_injection_rate = state_dict['error_injection_rate']

@@ -1107,7 +1277,14 @@ def _set_rerun_state_machine(rerun_state_machine) -> None:
def _safe_get_rank() -> int:
    """Internal function that safely checks and returns the rank of the caller."""
-    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    # If torch.distributed is not initialized, try to read environment variables.
+    try:
+        return int(os.environ.get("RANK", 0))
+    except (ValueError, TypeError):
+        return 0


def _compare_floats(a: torch.Tensor, b: torch.Tensor) -> float:
...
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -5,12 +5,14 @@
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass
-from typing import Union
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Union

import torch
from torch import Tensor

+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module

@@ -37,6 +39,9 @@ class MambaLayerSubmodules:
    mixer: Union[ModuleSpec, type] = IdentityOp
    mamba_bda: Union[ModuleSpec, type] = IdentityOp

+    # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
+    sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
+

class MambaLayer(MegatronModule):
    """

@@ -57,6 +62,7 @@ class MambaLayer(MegatronModule):
        """Initialize Mamba Layer."""
        super().__init__(config)
        self.config = config
+        self.submodules_config = submodules
        self.layer_number = layer_number
        self.residual_in_fp32 = residual_in_fp32
        self.hidden_dropout = config.hidden_dropout

@@ -114,3 +120,26 @@ class MambaLayer(MegatronModule):
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate the inference cache."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
+
+    def sharded_state_dict(
+        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
+    ) -> ShardedStateDict:
+        """
+        Generate a sharded state dictionary for the mamba layer.
+
+        Args:
+            prefix (str, optional): Prefix to be added to all keys in the state dict.
+            sharded_offsets (tuple, optional): Tuple of sharding offsets.
+            metadata (Optional[dict], optional): Additional metadata for sharding.
+
+        Returns:
+            ShardedStateDict: A dictionary containing the sharded state of the mamba layer.
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+        prefixed_map = {
+            f'{prefix}{k}': f'{prefix}{v}'
+            for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
+        }
+        if prefixed_map:
+            apply_prefix_mapping(sharded_state_dict, prefixed_map)
+        return sharded_state_dict
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -985,6 +985,14 @@ class ColumnParallelLinear(torch.nn.Module):
        """Keep compatibility with TE state dict."""
        return None

+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
+

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

@@ -1206,3 +1214,11 @@ class RowParallelLinear(torch.nn.Module):
    def get_extra_state(self) -> None:
        """Keep compatibility with TE state dict."""
        return None
+
+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
File mode changed from 100755 to 100644
@@ -5,10 +5,12 @@

import contextlib
import logging
+from functools import partial
+from typing import Union

import torch
from torch import _C
-from torch.cuda import _lazy_call
+from torch.cuda import _lazy_call, _lazy_init
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

@@ -21,17 +23,59 @@ from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

+try:
+    import transformer_engine  # pylint: disable=unused-import
+
+    HAVE_TE = True
+except ModuleNotFoundError:
+    HAVE_TE = False
+
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'


-def _set_cuda_rng_state(new_state, device=-1):
+def _get_cuda_rng_state(
+    device: Union[int, str, torch.device] = "cuda", clone: bool = False, graph_safe: bool = False
+) -> torch.Tensor:
+    """Return the random number generator state of the specified GPU.
+
+    Arguments:
+        device (int): The gpu to retrieve the rng state
+        clone (bool): Whether to also clone the retrieved RNG state
+        graph_safe (bool): Get the rng state in a graph safe manner.
+
+    This function is adapted from torch.cuda.random.get_rng_state()"""
+
+    # if not using cuda graphs, just use the builtin pytorch function
+    if not graph_safe:
+        return torch.cuda.random.get_rng_state(device=device)
+
+    _lazy_init()
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+    idx = device.index
+    if idx is None:
+        idx = torch.cuda.current_device()
+    default_generator = torch.cuda.default_generators[idx]
+    if clone:
+        return default_generator.clone_state()
+    return default_generator.graphsafe_get_state()
+
+
+def _set_cuda_rng_state(new_state: torch.Tensor, device: int = -1, graph_safe: bool = False):
    """Sets the random number generator state of the current GPU.

-    Argumentss:
+    Arguments:
        new_state (torch.ByteTensor): The desired state
+        device (int): The gpu to retrieve the rng state
+        graph_safe (bool): Set the rng state in a graph safe manner.

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.

@@ -56,7 +100,12 @@ def _set_cuda_rng_state(new_state, device=-1):
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
-        default_generator.set_state(new_state)
+
+        # if graph capturing, set the rng state in a cudagraphable way
+        if graph_safe:
+            default_generator.graphsafe_set_state(new_state)
+        else:
+            default_generator.set_state(new_state)

    _lazy_call(cb)
@@ -82,8 +131,17 @@ class CudaRNGStatesTracker:
    cuda state.
    """

-    def __init__(self):
+    def __init__(self, use_cudagraphable_rng=False):
        self.reset()
+        self.use_cudagraphable_rng = use_cudagraphable_rng
+
+        if self.use_cudagraphable_rng:
+            assert (
+                hasattr(torch.cuda.CUDAGraph, "register_generator_state")
+                and hasattr(torch.Generator, "graphsafe_set_state")
+                and hasattr(torch.Generator, "graphsafe_get_state")
+                and hasattr(torch.Generator, "clone_state")
+            ), "Tried using cudagraphs with RNG, however not detected in pytorch!"

    def is_initialized(self):
        """Checks if the internal RNG state has been set with set_states()."""

@@ -125,13 +183,20 @@ class CudaRNGStatesTracker:
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
-        # Get the current rng state.
-        orig_rng_state = torch.cuda.get_rng_state()
-        # Set the new state and store it.
-        torch.cuda.manual_seed(seed)
-        self.states_[name] = torch.cuda.get_rng_state()
-        # Reset rng state to what it was.
-        _set_cuda_rng_state(orig_rng_state)
+
+        # If available, create the state in a graph safe manner
+        if self.use_cudagraphable_rng:
+            new_state = _get_cuda_rng_state(clone=True, graph_safe=True)
+            new_state.manual_seed(seed)
+            self.states_[name] = new_state
+        else:
+            # Get the current rng state.
+            orig_rng_state = torch.cuda.get_rng_state()
+            # Set the new state and store it.
+            torch.cuda.manual_seed(seed)
+            self.states_[name] = torch.cuda.get_rng_state()
+            # Reset rng state to what it was.
+            _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):

@@ -141,9 +206,9 @@ class CudaRNGStatesTracker:
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
-        orig_cuda_rng_state = torch.cuda.get_rng_state()
+        orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
        # Set rng state to the desired one
-        _set_cuda_rng_state(self.states_[name])
+        _set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
        # Record cpu RNG state
        cpu_rng_state = torch.get_rng_state()
        # Do the stuff we wanted to do.

@@ -154,9 +219,9 @@ class CudaRNGStatesTracker:
            if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
                logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
            # Update the current rng state for later use.
-            self.states_[name] = torch.cuda.get_rng_state()
+            self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
            # And set the state to the original state we started with.
-            _set_cuda_rng_state(orig_cuda_rng_state)
+            _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)
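A minimal sketch of the fork() pattern the tracker exposes (this assumes an initialized CUDA device; the cudagraph-safe path additionally needs a PyTorch build with the graph-safe generator methods asserted in __init__ above):

import torch

tracker = CudaRNGStatesTracker()        # or CudaRNGStatesTracker(use_cudagraphable_rng=True)
tracker.add('model-parallel-rng', 1234)

with tracker.fork('model-parallel-rng'):
    # draws here consume and update the named state; afterwards the default CUDA RNG is restored
    w = torch.empty(4, 4, device='cuda')
    torch.nn.init.normal_(w)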
# RNG tracker object.

@@ -164,35 +229,85 @@ _CUDA_RNG_STATE_TRACKER = None
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False


-def initialize_rng_tracker(use_te_rng_tracker: bool = False):
+def initialize_rng_tracker(
+    use_te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
    """Create the RNG tracker. 'use_te_rng_tracker' determines whether to use
    Megatron or TransformerEngine's implementation.
    In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
    """
    global _CUDA_RNG_STATE_TRACKER
    global _CUDA_RNG_STATE_TRACKER_INITIALIZED
    if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
        return

-    if use_te_rng_tracker:
+    # Get the base tracker class
+    base_tracker = None
+    if HAVE_TE and use_te_rng_tracker:
        if not is_te_min_version("1.5.0"):
            raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

-        _CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker()
+        base_tracker = TECudaRNGStatesTracker
    else:
-        _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+        base_tracker = partial(CudaRNGStatesTracker, use_cudagraphable_rng=use_cudagraphable_rng)

+    if inference_rng_tracker:
+
+        class InferenceCudaRNGStatesTracker(base_tracker):
+            """RNG tracker for inference."""
+
+            def add(self, name, seed):
+                """Mirrors the interface from the training RNG tracker."""
+                pass
+
+            def set_states(self, states):
+                """Mirrors the interface from the training RNG tracker."""
+                pass
+
+            def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
+                """Mirrors the interface from the training RNG tracker."""
+                return contextlib.nullcontext()
+
+        tracker_class = InferenceCudaRNGStatesTracker
+    else:
+        tracker_class = base_tracker
+
+    _CUDA_RNG_STATE_TRACKER = tracker_class()
    _CUDA_RNG_STATE_TRACKER_INITIALIZED = True


-def get_cuda_rng_tracker(use_te_rng_tracker=False):
+def get_cuda_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=False):
    """Get cuda rng tracker."""
-    initialize_rng_tracker(use_te_rng_tracker)
+    initialize_rng_tracker(use_te_rng_tracker, inference_rng_tracker)
    return _CUDA_RNG_STATE_TRACKER


-def model_parallel_cuda_manual_seed(seed):
+def get_all_rng_states() -> bool:
+    """Returns all generator states used by the current `CudaRNGStatesTracker`."""
+    assert (
+        _CUDA_RNG_STATE_TRACKER_INITIALIZED
+    ), "Tried getting all rng states but RNG Tracker has not been initialized!"
+    if isinstance(_CUDA_RNG_STATE_TRACKER, CudaRNGStatesTracker):
+        return _CUDA_RNG_STATE_TRACKER.states_
+    # If TE is installed, check if we are using TE's RNG tracker
+    elif HAVE_TE and is_te_min_version("1.5.0"):
+        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker
+
+        if isinstance(_CUDA_RNG_STATE_TRACKER, TECudaRNGStatesTracker):
+            from transformer_engine.pytorch.distributed import get_all_rng_states
+
+            return get_all_rng_states()
+    # no valid tracker, return an empty dict
+    else:
+        return {}
+def model_parallel_cuda_manual_seed(seed, te_rng_tracker=False, inference_rng_tracker=False):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is

@@ -216,7 +331,7 @@ def model_parallel_cuda_manual_seed(seed):
    # Data parallel gets the original seed.
    data_parallel_seed = seed

-    initialize_rng_tracker()
+    initialize_rng_tracker(te_rng_tracker, inference_rng_tracker)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)

@@ -239,14 +354,16 @@ class CheckpointFunction(torch.autograd.Function):
    2) the states in the model parallel tracker are also properly tracked/set/reset.
    """

+    # pylint: disable=missing-function-docstring
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
+        """Forward pass."""
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():

@@ -265,8 +382,10 @@ class CheckpointFunction(torch.autograd.Function):
        return outputs

+    # pylint: disable=missing-function-docstring
    @staticmethod
    def backward(ctx, *args):
+        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "

@@ -280,7 +399,7 @@ class CheckpointFunction(torch.autograd.Function):
        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = torch.cuda.get_rng_state()
+        bwd_cuda_rng_state = _get_cuda_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.

@@ -302,7 +421,9 @@ class CheckpointFunction(torch.autograd.Function):
            outputs = (outputs,)

        # filter out non tensor outputs for backward pass
-        outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
+        outputs, args = zip(
+            *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args))
+        )
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None) + grads
...
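A hedged sketch of the new inference path introduced above: requesting the inference tracker yields a no-op tracker whose fork() is a null context, which is exactly what the added InferenceCudaRNGStatesTracker provides.

initialize_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=True)
tracker = get_cuda_rng_tracker(inference_rng_tracker=True)

with tracker.fork():   # contextlib.nullcontext() for the inference tracker
    pass               # no RNG state is swapped in or out here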
File mode changed from 100755 to 100644
@@ -24,12 +24,20 @@ class TimerBase(ABC):
    @abstractmethod
    def start(self, barrier=False):
-        """Start the timer."""
+        """Start the timer.
+
+        Args:
+            barrier (bool, optional): Synchronizes ranks before starting. Defaults to False.
+        """
        pass

    @abstractmethod
    def stop(self, barrier=False):
-        """Stop the timer."""
+        """Stop the timer.
+
+        Args:
+            barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
+        """
        pass

    @abstractmethod

@@ -39,7 +47,15 @@ class TimerBase(ABC):
    @abstractmethod
    def elapsed(self, reset=True, barrier=False):
-        """Calculates the elapsed time."""
+        """Calculates the elapsed time and restarts timer.
+
+        Args:
+            reset (bool, optional): Resets timer before restarting. Defaults to True.
+            barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
+
+        Returns:
+            float: Elapsed time.
+        """
        pass

@@ -59,7 +75,19 @@ class DummyTimer(TimerBase):
        return

    def elapsed(self, reset=True, barrier=False):
-        raise Exception('dummy timer should not be used to calculate elapsed time')
+        raise Exception(
+            'dummy timer should not be used to calculate elapsed time, '
+            'check if timer\'s log_level <= self._log_level.'
+        )
+
+    def active_time(self):
+        """Returns the cumulative duration the timer has been active.
+
+        Note: Not supported for DummyTimer.
+        """
+        raise Exception(
+            'active timer should not be used to calculate elapsed time, '
+            'check if timer\'s log_level <= self._log_level.'
+        )


class Timer(TimerBase):

@@ -155,7 +183,7 @@ class Timer(TimerBase):
        return _elapsed

    def active_time(self):
-        """Returns the active time."""
+        """Calculates the cumulative duration for which the timer has been active"""
        return self._active_time

@@ -397,8 +425,8 @@ class Timers:
        reset: bool = True,
        barrier: bool = False,
    ):
-        """Write timers to a tensorboard writer. Note that we only report maximum time across ranks
-        to tensorboard.
+        """Write timers to a tensorboard writer.
+
+        Note that we only report maximum time across ranks to tensorboard.

        Args:
            names (List[str]): Names of the timers to log.
...
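The reworded DummyTimer errors above are easiest to see in context. A hedged sketch (the Timers constructor arguments are assumed from megatron.core.timers): a per-timer log_level higher than the Timers log_level resolves to the DummyTimer, so elapsed() or active_time() on it raises with the new message.

from megatron.core.timers import Timers

timers = Timers(log_level=0, log_option='max')   # assumed constructor arguments

t = timers('data-loading', log_level=2)          # 2 > 0, so this resolves to the DummyTimer
t.start(barrier=False)                           # no-op on the DummyTimer
t.stop()
# t.elapsed()  # would raise: "dummy timer should not be used to calculate elapsed time, ..."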
File mode changed from 100755 to 100644