Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from typing import Callable
import torch
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
try:
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantModuleRegistry
from modelopt.torch.quantization.nn.modules.quant_linear import _QuantLinear
has_nvidia_modelopt = True
except Exception:
has_nvidia_modelopt = False
class Linear(torch.nn.Linear):
"""Local Linear impl as a replacement of TELinear."""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config
self._return_bias = skip_bias_add and bias
if skip_weight_param_allocation:
raise ValueError('torch.nn.Linear layers do not support skip_weight_param_allocation')
super().__init__(
in_features=input_size, out_features=output_size, bias=bias, dtype=config.params_dtype
)
for param in self.parameters():
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, 'allreduce', self.config.expert_model_parallel_size == 1)
else:
# Reduce the gradient on DP group
setattr(param, 'allreduce', True)
setattr(param, 'sequence_parallel', self.config.sequence_parallel)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
for k, v in state_dict.items():
if "_amax" in k or "_scale" in k:
if v.ndim == 0:
state_dict[k] = v.view(1)
sharded_state_dict = make_sharded_tensors_for_checkpoint(
state_dict, prefix, sharded_offsets=sharded_offsets
)
return sharded_state_dict
def forward(self, x):
"""Forward."""
out = super().forward(x)
if self._return_bias:
return out
return out, None
if has_nvidia_modelopt:
QuantModuleRegistry.register({Linear: Linear.__class__.__name__})(_QuantLinear)
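# Registering the local Linear with ModelOpt's QuantModuleRegistry lets mtq.quantize() swap
# instances of it for the corresponding quantized module (_QuantLinear), so this fallback
# implementation is picked up by quantization the same way the parallel linear layers are.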
class RealQuantTransformerLayer(TransformerLayer):
"""Real quantization transformer layer base class.
This base class initializes the default TransformerLayer and immediately
performs weight-only real quantization via TensorRT Model Optimizer.
All linear weights (Linear, ColumnParallelLinear, RowParallelLinear) picked
up will be replaced with a low-bit data type (default torch.uint8). If a sub-byte
real_quant_cfg is used, the weight shape is further halved.
This module cannot be trained (all parameters are frozen).
"""
verbose: bool = False
real_quant_cfg: str = "None"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if has_nvidia_modelopt and self.real_quant_cfg != "None":
REAL_QUANT_CFG_CHOICES = {
"fp8_real_quant": mtq.FP8_PER_TENSOR_REAL_QUANT_CFG,
"fp8_blockwise_real_quant": mtq.FP8_BLOCKWISE_REAL_QUANT_CFG,
}
mtq_cfg = REAL_QUANT_CFG_CHOICES.get(self.real_quant_cfg, None)
if mtq_cfg is None:
raise ValueError(
"RealQuantTransformerLayer does not support {}".format(self.real_quant_cfg)
)
self._collect_original_tensor_info()
mtq.quantize(self, mtq_cfg)
delattr(self, "_modelopt_state")
# Freeze all parameters since the real-quant linears cannot be trained.
for param in self.parameters():
param.requires_grad = False
if self.verbose:
self._report_quantize_tensor_info()
def _collect_original_tensor_info(self):
self._original_tensor_info = {}
for k, v in self.state_dict().items():
if isinstance(v, torch.Tensor):
self._original_tensor_info[k] = (str(v.dtype), str(v.shape))
def _report_quantize_tensor_info(self):
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
for k, v in self.state_dict().items():
if not isinstance(v, torch.Tensor):
continue
original_dtype, original_shape = self._original_tensor_info.get(k, ("-", "-"))
print(
"{:<64} {:<16} {:<32} {:<16} {:<32}".format(
k, original_dtype, original_shape, str(v.dtype), str(v.shape)
)
)
torch.distributed.barrier()
class FP8WeightTransformerLayer(RealQuantTransformerLayer):
"""FP8 weight transformer layer."""
real_quant_cfg: str = "fp8_real_quant"
class BlockwiseFP8WeightTransformerLayer(RealQuantTransformerLayer):
"""Blockwise FP8 weight transformer layer."""
real_quant_cfg: str = "fp8_blockwise_real_quant"
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
"""Mix the native spec with TENorm.
This is essentially the native local spec, except that the layernorm
implementation uses TENorm from Transformer-Engine.
"""
mamba_state_dict_keys_map = {}
transformer_state_dict_keys_map = {}
if remap_te_layernorm:
mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
transformer_state_dict_keys_map = {
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
}
mamba_layer = ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
norm=TENorm,
mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
),
),
mamba_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=mamba_state_dict_keys_map,
),
)
core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
attention_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=TENorm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=core_attention,
linear_proj=RowParallelLinear,
),
),
self_attn_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)
mlp_layer = ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
pre_mlp_layernorm=TENorm,
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
),
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map=transformer_state_dict_keys_map,
),
)
return ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
),
)
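# Hedged usage sketch (illustrative): building the ModelOpt-friendly Mamba stack spec with local
# core attention and TE-layernorm key remapping.
#
#   stack_spec = get_mamba_stack_modelopt_spec(local_core_attention=True, remap_te_layernorm=True)
#   # stack_spec.module is MambaStack; stack_spec.submodules carries the mamba/attention/mlp
#   # layer specs that MambaStack uses to build each layer.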
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import datetime
import inspect
import logging
import math
import os
import random
import re
from collections import defaultdict
from enum import Enum
from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Union
......@@ -12,6 +14,9 @@ from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Un
import numpy as np
import torch
import megatron.core.parallel_state as mpu
from megatron.core.dist_checkpointing.mapping import ShardedObject
"""DISCLAIMER: THIS IS AN EXPERIMENTAL FEATURE.
The rerun state machine implementation in this file is alpha-level code to help
......@@ -34,6 +39,7 @@ EXIT_CODE_RESUME_TO_DISAMBIGUATE: int = 16
EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int = 17
SerializableStateType = Union[list, dict]
DataIteratorArgType = Optional[Union["RerunDataIterator", list["RerunDataIterator"]]]
class Caller(NamedTuple):
......@@ -105,6 +111,17 @@ class RerunState(Enum):
RERUNNING_AGAIN_FROM_CHECKPOINT = 5
class RerunValidationStatus(str, Enum):
"""Enum representing the status of a record in the tracker log file"""
RERUN_DISABLED = 'rerun_disabled'
INITIAL_RUN = 'initial_run'
FIRST_RERUN_NOT_REPRODUCIBLE = 'first_rerun_not_reproducible'
FIRST_RERUN_REPRODUCIBLE = "first_rerun_reproducible"
SECOND_RERUN_NOT_REPRODUCIBLE = "second_rerun_not_reproducible"
SECOND_RERUN_REPRODUCIBLE = "second_rerun_reproducible"
COMPARISON_MATCH: float = 0.0
COMPARISON_MISMATCH: float = math.inf
......@@ -121,6 +138,7 @@ class RerunStateMachine:
state_restore_func: optional function to restore the state saved by state_save_func.
mode: operating mode for the rerun state machine, default is disabled.
error_injector: optional result injection engine, default is no result injection.
result_rejected_tracker_filename: optional name of file tracking `result rejected` events.
Example usage:
......@@ -170,6 +188,7 @@ class RerunStateMachine:
state_restore_func: Optional[Callable[[SerializableStateType], None]] = None,
mode: RerunMode = RerunMode.DISABLED,
error_injector: Optional["RerunErrorInjector"] = None,
result_rejected_tracker_filename: Optional[str] = None,
) -> None:
self.mode: RerunMode = mode
self.state: RerunState = RerunState.NOT_RUNNING_YET
......@@ -192,6 +211,18 @@ class RerunStateMachine:
self.suspicious_node: str = None
self.suspicious_device: int = None
# Keep track of `result_rejected` events.
# Make sure the file can be written to and abort if not.
self.result_rejected_tracker_filename = result_rejected_tracker_filename
if self.result_rejected_tracker_filename is not None:
try:
with open(self.result_rejected_tracker_filename, 'a'):
pass
except Exception as e:
raise RuntimeError(
f"RerunStateMachine result validation log cannot be appended to! ({e})"
)
self.saved_state: Optional[SerializableStateType] = None
self.state_save_func: Optional[Callable[[], SerializableStateType]] = state_save_func
self.state_restore_func: Optional[Callable[[SerializableStateType], None]] = (
......@@ -199,16 +230,19 @@ class RerunStateMachine:
)
self.data_iterator_checkpoints: Optional[list[SerializableStateType]] = None
self.last_loss: Optional[float] = None
self.large_value_counts: dict[str, int] = {}
self.max_values: dict[str, float] = {}
self.saved_results: dict[Call, Any] = {}
self.stats: dict[Caller, QuickStats] = defaultdict(lambda: QuickStats())
logger.warning(f"RerunStateMachine initialized in mode {mode}")
if _safe_get_rank() == 0:
logger.warning(f"RerunStateMachine initialized in mode {mode}")
def set_mode(self, mode: RerunMode) -> None:
"""Method to set the operating mode"""
logger.warning(f"Setting RerunStateMachine mode {mode}")
if _safe_get_rank() == 0:
logger.warning(f"Setting RerunStateMachine mode {mode}")
self.mode = mode
def get_mode(self) -> RerunMode:
......@@ -216,9 +250,7 @@ class RerunStateMachine:
return self.mode
def should_run_forward_backward(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> bool:
def should_run_forward_backward(self, data_iterator: DataIteratorArgType) -> bool:
"""Method instructing whether to (re)run the forward-backward pass.
Args:
......@@ -243,30 +275,20 @@ class RerunStateMachine:
self.validation_counts = defaultdict(int)
data_iterators: list[RerunDataIterator] = []
if self.mode != RerunMode.DISABLED and data_iterator is not None:
if not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
"data iterator is not wrapped with RerunDataIterator",
)
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)
# Are we about to start the initial run?
if self.state == RerunState.NOT_RUNNING_YET:
if self.mode == RerunMode.DISABLED:
self.state = RerunState.INITIAL_RUN
self.current_iteration += 1 # Increment self.current_iteration for reporting.
return True
if self.data_iterator_checkpoints is not None:
assert (
len(self.data_iterator_checkpoints) == len(data_iterators),
"data_iterator has different length than checkpointed data iterator",
)
assert len(self.data_iterator_checkpoints) == len(
data_iterators
), "data iterator has different length than checkpointed data iterator"
for i, d in enumerate(data_iterators):
d.set_checkpoint_state(self.data_iterator_checkpoints[i])
d.load_state_dict(self.data_iterator_checkpoints[i])
self.data_iterator_checkpoints = None
self._save_state()
if data_iterators:
......@@ -464,10 +486,28 @@ class RerunStateMachine:
verifying the result is the same.
"""
# Skip the validation check if the state machine is disabled or if we haven't run
# a full iteration yet. We cannot guarantee that a checkpoint can be taken before the
# optimizer has been stepped at least once.
if self.mode == RerunMode.DISABLED or self.current_iteration < 1:
# If reruns are disabled, still validate the result and throw a RuntimeError if it is
# rejected. This is a backward-compatible behavior.
if self.mode == RerunMode.DISABLED:
result_rejected: bool = rejection_func(result)
if result_rejected:
self._log_validation_error_to_file(
status=RerunValidationStatus.RERUN_DISABLED, result=result, message=message
)
rank: int = _safe_get_rank()
node: str = os.uname()[1]
device: int = torch.cuda.current_device()
full_message: str = (
f"Rank {rank}, node {node}, device {device}, "
f"iteration {self.current_iteration}: "
f"Unexpected result {result} (message='{message}')"
)
raise RuntimeError(full_message)
return
# Skip the validation on the first iteration, as we cannot guarantee a checkpoint can be
# taken before the optimizer has been stepped at least once.
if self.current_iteration < 1:
return
if comparison_func is None:
......@@ -523,6 +563,9 @@ class RerunStateMachine:
self.failed_validation_call = validation_call
self.initial_result = result
self.rerun_requested = True
self._log_validation_error_to_file(
status=RerunValidationStatus.INITIAL_RUN, result=result, message=message
)
logger.error(
f"Unexpected result {result} at {validation_call.caller.filename} "
f"line {validation_call.caller.lineno}, "
......@@ -547,6 +590,11 @@ class RerunStateMachine:
"First rerun: unexpected result is not reproducible within the tolerance "
f"({result} != {self.initial_result})"
)
self._log_validation_error_to_file(
status=RerunValidationStatus.FIRST_RERUN_NOT_REPRODUCIBLE,
result=result,
message=message,
)
log_failure("Possible transient error!")
else:
self.checkpoint_requested = True
......@@ -554,6 +602,11 @@ class RerunStateMachine:
# rerunning on the same GPU when we resume from the checkpoint.
self.suspicious_node = os.uname()[1]
self.suspicious_device = torch.cuda.current_device()
self._log_validation_error_to_file(
status=RerunValidationStatus.FIRST_RERUN_REPRODUCIBLE,
result=result,
message=message,
)
logger.warning(
"First rerun: unexpected result is reproducible within the tolerance "
f"({result} = {self.initial_result}). "
......@@ -571,12 +624,22 @@ class RerunStateMachine:
)
self.restart_again_requested = True
elif comparison > tolerance:
self._log_validation_error_to_file(
status=RerunValidationStatus.SECOND_RERUN_NOT_REPRODUCIBLE,
result=result,
message=message,
)
logger.warning(
"Second rerun: unexpected result is not reproducible on a different GPU, "
f"therefore was likely incorrect ({result} != {self.initial_result})"
)
log_failure("Possible persistent error!")
else:
self._log_validation_error_to_file(
status=RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE,
result=result,
message=message,
)
logger.warning(
"Second rerun: unexpected result is reproducible on a different GPU, "
f"therefore it was likely correct ({result} = {self.initial_result})"
......@@ -587,13 +650,31 @@ class RerunStateMachine:
else:
raise RuntimeError("Should not be here")
def is_spiky_loss(self, loss_tensor: torch.Tensor, threshold: float) -> bool:
"""Helper method to estimate whether a loss is spiky.
def is_unexpectedly_large(
self,
result: torch.Tensor,
threshold: float,
context: str,
num_samples: int = 100,
resample: bool = False,
) -> bool:
"""Helper method to estimate whether a result is unexpectedly large.
Some calculation errors manifest themselves as results with unexpectedly large
exponents, e.g. spiky loss or grads. This method keeps track of a value over time
and flags it if it exceeds a certain threshold expressed as a multiple of
the max value observed.
Args:
loss_tensor: a zero-dim tensor containing the current loss.
threshold: a float representing the minimum relative variation
characterizing a spiky loss (e.g. 0.1 means +/- 10%).
threshold: a float representing the minimum trigger threshold
e.g. 10 means > 10x max absolute value observed.
context: a string identifying the value. This is used to differentiate
between different invocations of validate_result targeting different
values, e.g. loss and grads.
num_samples: the sample size used to estimate the max value.
Default is 100 value samples.
resample: whether to resample the max value. Default is False.
Returns:
A boolean telling whether the current value exceeds the max value observed
so far by a factor greater than the threshold
......@@ -612,37 +693,42 @@ class RerunStateMachine:
loss = loss_fn(outputs)
rerun_machine.validate_result(
result=loss,
rejection_func=partial(rerun_machine.is_spiky_loss, threshold=0.1),
rejection_func=partial(
rerun_machine.is_unexpectedly_large,
threshold=10,
context="loss",
),
message="Spiky loss",
tolerance=0.0,
fatal=False,
)
"""
loss: float = loss_tensor.item()
result: bool = False
if self.last_loss is not None:
# Ignore NaNs, and consider infinite loss as spiky.
if math.isnan(loss) or math.isnan(self.last_loss):
result = False
elif math.isinf(loss) or math.isinf(self.last_loss):
result = True
else:
result = math.fabs(loss - self.last_loss) / self.last_loss >= threshold
self.last_loss = loss
return result
value: float = math.fabs(result.item())
# Ignore NaNs and Infs. They should be checked separately.
if math.isnan(value) or math.isinf(value):
return False
if resample or context not in self.large_value_counts:
self.large_value_counts[context] = 0
if self.large_value_counts[context] < num_samples:
self.large_value_counts[context] += 1
self.max_values[context] = max(self.max_values.get(context, 0.0), value)
if self.large_value_counts[context] == num_samples:
logger.warning(f"Max value for {context}: {self.max_values[context]}")
return False
return value >= self.max_values[context] * threshold
def get_checkpoint_state(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> list[dict[str, Any]]:
def state_dict(self, data_iterator: DataIteratorArgType, use_dist_ckpt: bool) -> dict[str, Any]:
"""Method that returns a state dict to be checkpointed.
Args:
data_iterator: the data iterator that needs to be checkpointed (or None
if this checkpoint is not requested by the rerun state machine).
use_dist_ckpt: generate a distributed checkpoint.
Returns:
A list of state dicts, each state dict representing the rerun state machine
for one rank.
A state dict representing the rerun state machine.
Example usage:
......@@ -651,65 +737,63 @@ class RerunStateMachine:
...
rerun_state_machine = get_rerun_state_machine()
checkpoint['rerun_state_machine'] = (
rerun_state_machine.get_checkpoint_state(data_iterator)
rerun_state_machine.state_dict(data_iterator, False)
)
...
return checkpoint
"""
data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif isinstance(data_iterator, (list, tuple)):
data_iterators = data_iterator
else:
data_iterators = [data_iterator] if data_iterator is not None else []
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
"data iterator is not wrapped with RerunDataIterator",
)
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)
state: dict[str, Any] = {
# The RerunStateMachine state is different across all ranks. Therefore it needs to be
# checkpointed using a ShardedObject. However, we keep the common state in the non-sharded
# (common) checkpoint. This allows us to verify whether a checkpoint contains a
# RerunStateMachine state by checking the common checkpoint.
state_dict: dict[str, Any] = {
'mode': self.mode,
'state': self.state,
'current_iteration': self.current_iteration,
'rerun_requested': self.rerun_requested,
'checkpoint_requested': self.checkpoint_requested,
'restart_again_requested': self.restart_again_requested,
'continue_requested': self.continue_requested,
# logged_sdc_enabled should not be saved (set at the job startup time).
'error_injector_checkpoint': self.error_injector.get_checkpoint_state(),
# validation_counts should not be saved (reset at the beginning of the training loop).
'failed_validation_call': self.failed_validation_call,
'initial_result': self.initial_result,
'suspicious_node': self.suspicious_node,
'suspicious_device': self.suspicious_device,
# No need to save saved_state (RNG state already captured in checkpoint).
'data_iterator_checkpoints': (
[d.get_checkpoint_state() for d in data_iterators] if data_iterators else None
),
'last_loss': self.last_loss,
# No need to save saved_results and stats (resets when job resumes).
'sharded': {
'state': self.state,
'current_iteration': self.current_iteration,
'rerun_requested': self.rerun_requested,
'checkpoint_requested': self.checkpoint_requested,
'restart_again_requested': self.restart_again_requested,
'continue_requested': self.continue_requested,
# logged_sdc_enabled should not be saved (set at the job startup time).
'error_injector_checkpoint': self.error_injector.state_dict(),
# validation_counts should not be saved (reset at start of training loop).
'failed_validation_call': self.failed_validation_call,
'initial_result': self.initial_result,
'suspicious_node': self.suspicious_node,
'suspicious_device': self.suspicious_device,
# No need to save saved_state (RNG state already captured in checkpoint).
'data_iterator_checkpoints': (
[d.state_dict() for d in data_iterators] if data_iterators else None
),
'large_value_counts': self.large_value_counts,
'max_values': self.max_values,
# No need to save saved_results and stats (resets when job resumes).
},
}
state_list: list[dict[str, Any]]
if (
torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
and self.mode != RerunMode.DISABLED
):
state_list = [None for i in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(state_list, state)
else:
state_list = [state]
return state_list
if use_dist_ckpt:
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
state_dict['sharded'] = ShardedObject(
'rerun_state_machine_state',
state_dict['sharded'],
(pp_size, tp_size),
(pp_rank, tp_rank),
replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
)
return state_dict
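# Hedged usage sketch (illustrative): with use_dist_ckpt=True the per-rank portion of the state is
# wrapped in a ShardedObject keyed by (pp_rank, tp_rank) and replicated across data-parallel ranks,
# while the common key ('mode') stays in the non-sharded checkpoint.
#
#   checkpoint['rerun_state_machine'] = rerun_state_machine.state_dict(data_iterator, True)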
def set_checkpoint_state(self, state_list: list[dict[str, Any]]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Method that restores the state from a checkpoint.
Args:
state_list: the list of state dicts saved in the checkpoint and originally
obtained from get_checkpoint_state().
state_dict: the state dict saved in the checkpoint and originally
obtained from state_dict().
Returns:
None
......@@ -719,31 +803,59 @@ class RerunStateMachine:
...
if 'rerun_state_machine' in checkpoint:
rerun_state_machine = get_rerun_state_machine()
rerun_state_machine.set_checkpoint_state(checkpoint['rerun_state_machine'])
rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
"""
if self.mode == RerunMode.DISABLED:
if _safe_get_rank() == 0:
logger.warning(
"RerunStateMachine disabled via CLI, ignoring machine state saved in checkpoint"
)
return
rank: int = _safe_get_rank()
if rank == 0:
if state_dict['mode'] == RerunMode.DISABLED:
if _safe_get_rank() == 0:
logger.warning(
"RerunStateMachine disabled in checkpoint but enabled via CLI, "
"ignoring machine state saved in checkpoint"
)
return
if _safe_get_rank() == 0:
logger.warning(
"Getting RerunStaeMachine state from checkpoint, args rerun options ignored"
"Getting RerunStateMachine state from checkpoint, CLI rerun args ignored"
)
state = state_list[rank]
self.mode = state['mode']
self.state = state['state']
self.current_iteration = state['current_iteration']
self.rerun_requested = state['rerun_requested']
self.checkpoint_requested = state['checkpoint_requested']
self.restart_again_requested = state['restart_again_requested']
self.continue_requested = state['continue_requested']
self.error_injector.set_checkpoint_state(state['error_injector_checkpoint'])
self.failed_validation_call = state['failed_validation_call']
self.initial_result = state['initial_result']
self.suspicious_node = state['suspicious_node']
self.suspicious_device = state['suspicious_device']
self.data_iterator_checkpoints = state['data_iterator_checkpoints']
self.last_loss = state['last_loss']
self.mode = state_dict['mode']
sharded_dict = state_dict['sharded']
self.state = sharded_dict['state']
self.current_iteration = sharded_dict['current_iteration']
self.rerun_requested = sharded_dict['rerun_requested']
self.checkpoint_requested = sharded_dict['checkpoint_requested']
self.restart_again_requested = sharded_dict['restart_again_requested']
self.continue_requested = sharded_dict['continue_requested']
self.error_injector.load_state_dict(sharded_dict['error_injector_checkpoint'])
self.failed_validation_call = sharded_dict['failed_validation_call']
self.initial_result = sharded_dict['initial_result']
self.suspicious_node = sharded_dict['suspicious_node']
self.suspicious_device = sharded_dict['suspicious_device']
self.data_iterator_checkpoints = sharded_dict['data_iterator_checkpoints']
self.large_value_counts = sharded_dict['large_value_counts']
self.max_values = sharded_dict['max_values']
def _sanitize_data_iterators(
self, data_iterator: DataIteratorArgType
) -> list["RerunDataIterator"]:
data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
data_iterators = [d for d in data_iterators if d is not None]
for d in data_iterators:
assert isinstance(
d, RerunDataIterator
), "data iterator is not wrapped with RerunDataIterator"
return data_iterators
def _get_validation_call_info(self) -> Call:
"""Internal method to get the context about the caller to validate_result()."""
......@@ -817,6 +929,64 @@ class RerunStateMachine:
logger.info(f" From {caller.filename}, line {caller.lineno}:")
logger.info(f" {stats.print_stats()}")
def _log_validation_error_to_file(
self, status: RerunValidationStatus, result: Any, message: str
) -> None:
if self.result_rejected_tracker_filename is not None:
# Append to log.
try:
rank: int = _safe_get_rank()
node: str = os.uname()[1]
device: int = torch.cuda.current_device()
with open(self.result_rejected_tracker_filename, 'a') as f:
print(
f"ts={datetime.datetime.now()} node={node} device={device} "
f"jobID={os.getenv('SLURM_JOBID', 'N/A')} rank={rank} "
f"iteration={self.current_iteration} status={status} result={result} "
f"message='{message}'",
file=f,
)
except Exception as e:
logger.error(f"Could not log validation error! ({e})")
@classmethod
def get_skipped_iterations_from_tracker_file(cls, tracker_file_name: str) -> list[int]:
"""Get list of iterations to skip from results recorded in tracker file. If an
"abnormality" (e.g., NaN or infinity in gradient) is seen more than once on a
given rank and iteration, the corresponding iteration is skipped.
Args:
tracker_file_name (str): Name of tracker file.
Returns:
list[int]: List of iterations to skip.
"""
iterations_to_skip: set[int] = set()
seen: set[Tuple[int, int]] = set()
regex = r"ts=.+ node=.+ device=.+ jobID=.+ rank=(.+) iteration=(.+) status=(.+) .+"
try:
with open(tracker_file_name, 'r') as f:
for line in f.readlines():
match = re.search(regex, line)
if match:
rank = int(match[1])
iteration = int(match[2])
status = match[3]
# Skip an iteration if:
# - Reruns were disabled and it has failed on the same rank twice.
# or
# - Reruns were enabled and it was reproducible on the 2nd rerun
if status == RerunValidationStatus.RERUN_DISABLED:
if (rank, iteration) in seen:
iterations_to_skip.add(iteration)
else:
seen.add((rank, iteration))
elif status == RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE:
iterations_to_skip.add(iteration)
except Exception as e:
logger.error(f"Could not parse iterations to skip in tracker file! ({e})")
return sorted(iterations_to_skip)
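# Illustrative tracker-file line in the format written by _log_validation_error_to_file (all values below are made up):
#   ts=2025-01-01 00:00:00.000000 node=node0 device=0 jobID=N/A rank=3 iteration=1200 status=second_rerun_reproducible result=inf message='Spiky loss'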
class RerunDataIterator:
"""A wrapper class for data iterators that adds replay capability.
......@@ -837,8 +1007,8 @@ class RerunDataIterator:
replay_data_iterator = RerunDataIterator(data_iterator)
"""
def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
def __init__(self, iterable: Iterable[Any]) -> None:
self.iterable: Iterable[Any] = iterable
self.saved_microbatches: list[Any] = []
self.replaying: bool = False
self.replay_pos: int = 0
......@@ -870,7 +1040,7 @@ class RerunDataIterator:
self.replaying = False
self.saved_microbatches = []
def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the iterator as a serializable dict."""
return {
......@@ -879,7 +1049,7 @@ class RerunDataIterator:
'replay_pos': self.replay_pos,
}
def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""
self.saved_microbatches = state_dict['saved_microbatches']
......@@ -1051,7 +1221,7 @@ class RerunErrorInjector:
else:
raise RuntimeError("Should not be here")
def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the error injector as a serializable dict."""
return {
......@@ -1061,7 +1231,7 @@ class RerunErrorInjector:
'injected_error_type': self.injected_error_type,
}
def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""
self.error_injection_rate = state_dict['error_injection_rate']
......@@ -1107,7 +1277,14 @@ def _set_rerun_state_machine(rerun_state_machine) -> None:
def _safe_get_rank() -> int:
"""Internal function that safely checks and returns the rank of the caller."""
return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if torch.distributed.is_initialized():
return torch.distributed.get_rank()
# If torch.distributed is not initialized, try to read environment variables.
try:
return int(os.environ.get("RANK", 0))
except (ValueError, TypeError):
return 0
def _compare_floats(a: torch.Tensor, b: torch.Tensor) -> float:
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -5,12 +5,14 @@
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Union
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
import torch
from torch import Tensor
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
......@@ -37,6 +39,9 @@ class MambaLayerSubmodules:
mixer: Union[ModuleSpec, type] = IdentityOp
mamba_bda: Union[ModuleSpec, type] = IdentityOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
class MambaLayer(MegatronModule):
"""
......@@ -57,6 +62,7 @@ class MambaLayer(MegatronModule):
"""Initialize Mamba Layer."""
super().__init__(config)
self.config = config
self.submodules_config = submodules
self.layer_number = layer_number
self.residual_in_fp32 = residual_in_fp32
self.hidden_dropout = config.hidden_dropout
......@@ -114,3 +120,26 @@ class MambaLayer(MegatronModule):
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
"""Allocate the inference cache."""
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the mamba layer.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the mamba layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
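# Illustrative effect of the key remapping above (names are examples, assuming the spec sets
# sharded_state_dict_keys_map={'norm.': 'mixer.in_proj.layer_norm_'}): a key such as
# 'decoder.layers.0.norm.weight' is stored in the checkpoint as
# 'decoder.layers.0.mixer.in_proj.layer_norm_weight', matching the TE fused-layernorm naming.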
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -985,6 +985,14 @@ class ColumnParallelLinear(torch.nn.Module):
"""Keep compatibility with TE state dict."""
return None
def __repr__(self):
tp = self.output_size // self.output_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
......@@ -1206,3 +1214,11 @@ class RowParallelLinear(torch.nn.Module):
def get_extra_state(self) -> None:
"""Keep compatibility with TE state dict."""
return None
def __repr__(self):
tp = self.input_size // self.input_size_per_partition
use_bias = self.bias is not None
return (
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)
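# Illustrative __repr__ output for both parallel linears (numbers are made up, TP size 2):
#   ColumnParallelLinear(in_features=4096, out_features=16384, bias=True, TP=2)
#   RowParallelLinear(in_features=16384, out_features=4096, bias=True, TP=2)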
File mode changed from 100755 to 100644
......@@ -5,10 +5,12 @@
import contextlib
import logging
from functools import partial
from typing import Union
import torch
from torch import _C
from torch.cuda import _lazy_call
from torch.cuda import _lazy_call, _lazy_init
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
......@@ -21,17 +23,59 @@ from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_da
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
try:
import transformer_engine # pylint: disable=unused-import
HAVE_TE = True
except ModuleNotFoundError:
HAVE_TE = False
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
def _get_cuda_rng_state(
device: Union[int, str, torch.device] = "cuda", clone: bool = False, graph_safe: bool = False
) -> torch.Tensor:
"""Return the random number generator state of the specified GPU.
Arguments:
device (int): The gpu to retrieve the rng state
clone (bool): Whether to also clone the retrieved RNG state
graph_safe (bool): Get the rng state in a graph safe manner.
This function is adapted from torch.cuda.random.get_rng_state()"""
# if not using cuda graphs, just use the builtin pytorch function
if not graph_safe:
return torch.cuda.random.get_rng_state(device=device)
_lazy_init()
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
if clone:
return default_generator.clone_state()
return default_generator.graphsafe_get_state()
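# Hedged usage sketch (illustrative): saving and restoring a graph-safe generator state.
# graph_safe=True routes through the generator's graphsafe_get_state/graphsafe_set_state APIs
# instead of torch.cuda.random.get_rng_state/set_rng_state.
#
#   saved = _get_cuda_rng_state(clone=True, graph_safe=True)
#   ...  # e.g. reseed or capture a CUDA graph
#   _set_cuda_rng_state(saved, graph_safe=True)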
def _set_cuda_rng_state(new_state: torch.Tensor, device: int = -1, graph_safe: bool = False):
"""Sets the random number generator state of the current GPU.
Argumentss:
Arguments:
new_state (torch.ByteTensor): The desired state
device (int): The gpu to set the rng state on
graph_safe (bool): Set the rng state in a graph safe manner.
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
......@@ -56,7 +100,12 @@ def _set_cuda_rng_state(new_state, device=-1):
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
# if graph capturing, set the rng state in a cudagraphable way
if graph_safe:
default_generator.graphsafe_set_state(new_state)
else:
default_generator.set_state(new_state)
_lazy_call(cb)
......@@ -82,8 +131,17 @@ class CudaRNGStatesTracker:
cuda state.
"""
def __init__(self):
def __init__(self, use_cudagraphable_rng=False):
self.reset()
self.use_cudagraphable_rng = use_cudagraphable_rng
if self.use_cudagraphable_rng:
assert (
hasattr(torch.cuda.CUDAGraph, "register_generator_state")
and hasattr(torch.Generator, "graphsafe_set_state")
and hasattr(torch.Generator, "graphsafe_get_state")
and hasattr(torch.Generator, "clone_state")
), "Tried using cudagraphs with RNG, however not detected in pytorch!"
def is_initialized(self):
"""Checks if the internal RNG state has been set wirth set_states()."""
......@@ -125,13 +183,20 @@ class CudaRNGStatesTracker:
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
# If available, create the state in a graph safe manner
if self.use_cudagraphable_rng:
new_state = _get_cuda_rng_state(clone=True, graph_safe=True)
new_state.manual_seed(seed)
self.states_[name] = new_state
else:
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
......@@ -141,9 +206,9 @@ class CudaRNGStatesTracker:
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
_set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
# Record cpu RNG state
cpu_rng_state = torch.get_rng_state()
# Do the stuff we wanted to do.
......@@ -154,9 +219,9 @@ class CudaRNGStatesTracker:
if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
_set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)
# RNG tracker object.
......@@ -164,35 +229,85 @@ _CUDA_RNG_STATE_TRACKER = None
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False
def initialize_rng_tracker(use_te_rng_tracker: bool = False):
def initialize_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
):
"""Create the RNG tracker. 'use_te_rng_tracker' determines whether to use
Megatron or TransformerEngine's implementation.
In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
"""
global _CUDA_RNG_STATE_TRACKER
global _CUDA_RNG_STATE_TRACKER_INITIALIZED
if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
return
if use_te_rng_tracker:
# Get the base tracker class
base_tracker = None
if HAVE_TE and use_te_rng_tracker:
if not is_te_min_version("1.5.0"):
raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker
_CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker()
base_tracker = TECudaRNGStatesTracker
tracker_kwargs = {}
else:
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
# Keep the constructor kwargs separate (rather than binding them with functools.partial)
# so that the inference subclass below derives from an actual tracker class.
base_tracker = CudaRNGStatesTracker
tracker_kwargs = {"use_cudagraphable_rng": use_cudagraphable_rng}
if inference_rng_tracker:
class InferenceCudaRNGStatesTracker(base_tracker):
"""RNG tracker for inference."""
def add(self, name, seed):
"""Mirrors the interface from the training RNG tracker."""
pass
def set_states(self, states):
"""Mirrors the interface from the training RNG tracker."""
pass
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Mirrors the interface from the training RNG tracker."""
return contextlib.nullcontext()
tracker_class = InferenceCudaRNGStatesTracker
else:
tracker_class = base_tracker
_CUDA_RNG_STATE_TRACKER = tracker_class(**tracker_kwargs)
_CUDA_RNG_STATE_TRACKER_INITIALIZED = True
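# Hedged usage sketch (illustrative): creating a local, cudagraph-safe tracker whose fork() is a
# no-op for inference decoding. Subsequent calls are no-ops thanks to the
# _CUDA_RNG_STATE_TRACKER_INITIALIZED guard.
#
#   initialize_rng_tracker(inference_rng_tracker=True, use_cudagraphable_rng=True)
#   with get_cuda_rng_tracker().fork():
#       ...  # fork() returns a nullcontext in the inference tracker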
def get_cuda_rng_tracker(use_te_rng_tracker=False):
def get_cuda_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=False):
"""Get cuda rng tracker."""
initialize_rng_tracker(use_te_rng_tracker)
initialize_rng_tracker(use_te_rng_tracker, inference_rng_tracker)
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
def get_all_rng_states() -> dict:
"""Returns all generator states used by the current `CudaRNGStatesTracker`."""
assert (
_CUDA_RNG_STATE_TRACKER_INITIALIZED
), "Tried getting all rng states but RNG Tracker has not been initalized!"
if isinstance(_CUDA_RNG_STATE_TRACKER, CudaRNGStatesTracker):
return _CUDA_RNG_STATE_TRACKER.states_
# If TE is installed, check if we are using TE's RNG tracker
elif HAVE_TE and is_te_min_version("1.5.0"):
from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker
if isinstance(_CUDA_RNG_STATE_TRACKER, TECudaRNGStatesTracker):
from transformer_engine.pytorch.distributed import get_all_rng_states
return get_all_rng_states()
# no valid tracker, return an empty dict
else:
return {}
def model_parallel_cuda_manual_seed(seed, te_rng_tracker=False, inference_rng_tracker=False):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
......@@ -216,7 +331,7 @@ def model_parallel_cuda_manual_seed(seed):
# Data parallel gets the original seed.
data_parallel_seed = seed
initialize_rng_tracker()
initialize_rng_tracker(te_rng_tracker, inference_rng_tracker)
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
......@@ -239,14 +354,16 @@ class CheckpointFunction(torch.autograd.Function):
2) the states in the model parallel tracker are also properly tracked/set/reset.
"""
# pylint: disable=missing-function-docstring
@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
"""Forward pass."""
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state = _get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
......@@ -265,8 +382,10 @@ class CheckpointFunction(torch.autograd.Function):
return outputs
# pylint: disable=missing-function-docstring
@staticmethod
def backward(ctx, *args):
"""Backward pass."""
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
......@@ -280,7 +399,7 @@ class CheckpointFunction(torch.autograd.Function):
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state = _get_cuda_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
......@@ -302,7 +421,9 @@ class CheckpointFunction(torch.autograd.Function):
outputs = (outputs,)
# filter out non tensor outputs for backward pass
outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
outputs, args = zip(
*filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args))
)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None, None) + grads
......
File mode changed from 100755 to 100644
......@@ -24,12 +24,20 @@ class TimerBase(ABC):
@abstractmethod
def start(self, barrier=False):
"""Start the timer."""
"""Start the timer.
Args:
barrier (bool, optional): Synchronizes ranks before starting. Defaults to False.
"""
pass
@abstractmethod
def stop(self, barrier=False):
"""Stop the timer."""
"""Stop the timer.
Args:
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
"""
pass
@abstractmethod
......@@ -39,7 +47,15 @@ class TimerBase(ABC):
@abstractmethod
def elapsed(self, reset=True, barrier=False):
"""Calculates the elapsed time."""
"""Calculates the elapsed time and restarts timer.
Args:
reset (bool, optional): Resets timer before restarting. Defaults to True.
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
Returns:
float: Elapsed time.
"""
pass
......@@ -59,7 +75,19 @@ class DummyTimer(TimerBase):
return
def elapsed(self, reset=True, barrier=False):
raise Exception('dummy timer should not be used to calculate elapsed time')
raise Exception(
'dummy timer should not be used to calculate elapsed time, '
'check if timer\'s log_level <= self._log_level.'
)
def active_time(self):
"""Returns the cumulative duration the timer has been active.
Note: Not supported for DummyTimer.
"""
raise Exception(
'dummy timer should not be used to calculate active time, '
'check if timer\'s log_level <= self._log_level.'
)
class Timer(TimerBase):
......@@ -155,7 +183,7 @@ class Timer(TimerBase):
return _elapsed
def active_time(self):
"""Returns the active time."""
"""Calculates the cumulative duration for which the timer has been active"""
return self._active_time
......@@ -397,8 +425,8 @@ class Timers:
reset: bool = True,
barrier: bool = False,
):
"""Write timers to a tensorboard writer. Note that we only report maximum time across ranks
to tensorboard.
"""Write timers to a tensorboard writer.
Note that we only report maximum time across ranks to tensorboard.
Args:
names (List[str]): Names of the timers to log.
......
File mode changed from 100755 to 100644