Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec except for the layernorm implementation
    is using TENorm from Transformer-Engine.
    """
    mamba_state_dict_keys_map = {}
    transformer_state_dict_keys_map = {}
    if remap_te_layernorm:
        mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
        transformer_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }

    mamba_layer = ModuleSpec(
        module=MambaLayer,
        submodules=MambaLayerSubmodules(
            norm=TENorm,
            mixer=ModuleSpec(
                module=MambaMixer,
                submodules=MambaMixerSubmodules(
                    in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
                ),
            ),
            mamba_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=mamba_state_dict_keys_map,
        ),
    )

    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    attention_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    mlp_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    return ModuleSpec(
        module=MambaStack,
        submodules=MambaStackSubmodules(
            mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
        ),
    )
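
As a rough usage sketch (not part of this commit: the config values and the provider function below are illustrative assumptions), the spec returned above is typically handed to the Megatron Core Mamba model:

# Hedged sketch: wiring the ModelOpt Mamba spec into a model provider.
# The TransformerConfig values, vocab size, and sequence length are placeholders.
from megatron.core.models.mamba import MambaModel
from megatron.core.transformer.transformer_config import TransformerConfig

def model_provider() -> MambaModel:
    config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=4)
    spec = get_mamba_stack_modelopt_spec(remap_te_layernorm=True)
    return MambaModel(
        config=config,
        mamba_stack_spec=spec,
        vocab_size=32000,
        max_sequence_length=1024,
    )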
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

+import datetime
 import inspect
 import logging
 import math
 import os
 import random
+import re
 from collections import defaultdict
 from enum import Enum
 from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Union

@@ -12,6 +14,9 @@ from typing import Any, Callable, Iterable, NamedTuple, Optional, Set, Tuple, Un
 import numpy as np
 import torch

+import megatron.core.parallel_state as mpu
+from megatron.core.dist_checkpointing.mapping import ShardedObject

 """DISCLAIMER: THIS IS AN EXPERIMENTAL FEATURE.

 The rerun state machine implementation in this file is alpha-level code to help

@@ -34,6 +39,7 @@ EXIT_CODE_RESUME_TO_DISAMBIGUATE: int = 16
 EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int = 17

 SerializableStateType = Union[list, dict]
+DataIteratorArgType = Optional[Union["RerunDataIterator", list["RerunDataIterator"]]]

 class Caller(NamedTuple):

@@ -105,6 +111,17 @@ class RerunState(Enum):
     RERUNNING_AGAIN_FROM_CHECKPOINT = 5

+class RerunValidationStatus(str, Enum):
+    """Enum representing the status of a record in the tracker log file"""
+
+    RERUN_DISABLED = 'rerun_disabled'
+    INITIAL_RUN = 'initial_run'
+    FIRST_RERUN_NOT_REPRODUCIBLE = 'first_rerun_not_reproducible'
+    FIRST_RERUN_REPRODUCIBLE = "first_rerun_reproducible"
+    SECOND_RERUN_NOT_REPRODUCIBLE = "second_rerun_not_reproducible"
+    SECOND_RERUN_REPRODUCIBLE = "second_rerun_reproducible"
COMPARISON_MATCH: float = 0.0
COMPARISON_MISMATCH: float = math.inf

@@ -121,6 +138,7 @@ class RerunStateMachine:
        state_restore_func: optional function to restore the state saved by state_save_func.
        mode: operating mode for the rerun state machine, default is disabled.
        error_injector: optional result injection engine, default is no result injection.
+        result_rejected_tracker_filename: optional name of file tracking `result rejected` events.

    Example usage:

@@ -170,6 +188,7 @@ class RerunStateMachine:
        state_restore_func: Optional[Callable[[SerializableStateType], None]] = None,
        mode: RerunMode = RerunMode.DISABLED,
        error_injector: Optional["RerunErrorInjector"] = None,
+        result_rejected_tracker_filename: Optional[str] = None,
    ) -> None:
        self.mode: RerunMode = mode
        self.state: RerunState = RerunState.NOT_RUNNING_YET

@@ -192,6 +211,18 @@ class RerunStateMachine:
        self.suspicious_node: str = None
        self.suspicious_device: int = None

+        # Keep track of `result_rejected` events.
+        # Make sure the file can be written to and abort if not.
+        self.result_rejected_tracker_filename = result_rejected_tracker_filename
+        if self.result_rejected_tracker_filename is not None:
+            try:
+                with open(self.result_rejected_tracker_filename, 'a'):
+                    pass
+            except Exception as e:
+                raise RuntimeError(
+                    f"RerunStateMachine result validation log cannot be appended to! ({e})"
+                )
+
        self.saved_state: Optional[SerializableStateType] = None
        self.state_save_func: Optional[Callable[[], SerializableStateType]] = state_save_func
        self.state_restore_func: Optional[Callable[[SerializableStateType], None]] = (

@@ -199,16 +230,19 @@ class RerunStateMachine:
        )
        self.data_iterator_checkpoints: Optional[list[SerializableStateType]] = None

-        self.last_loss: Optional[float] = None
+        self.large_value_counts: dict[str, int] = {}
+        self.max_values: dict[str, float] = {}
        self.saved_results: dict[Call, Any] = {}
        self.stats: dict[Caller, QuickStats] = defaultdict(lambda: QuickStats())

-        logger.warning(f"RerunStateMachine initialized in mode {mode}")
+        if _safe_get_rank() == 0:
+            logger.warning(f"RerunStateMachine initialized in mode {mode}")

    def set_mode(self, mode: RerunMode) -> None:
        """Method to set the operating mode"""
-        logger.warning(f"Setting RerunStateMachine mode {mode}")
+        if _safe_get_rank() == 0:
+            logger.warning(f"Setting RerunStateMachine mode {mode}")
        self.mode = mode

    def get_mode(self) -> RerunMode:

@@ -216,9 +250,7 @@ class RerunStateMachine:
        return self.mode

-    def should_run_forward_backward(
-        self, data_iterator: Optional[Union["RerunDataIterator", list]]
-    ) -> bool:
+    def should_run_forward_backward(self, data_iterator: DataIteratorArgType) -> bool:
        """Method instructing whether to (re)run the forward-backward pass.

        Args:

@@ -243,30 +275,20 @@ class RerunStateMachine:
        self.validation_counts = defaultdict(int)

-        data_iterators: list[RerunDataIterator] = []
-        if self.mode != RerunMode.DISABLED and data_iterator is not None:
-            if not isinstance(data_iterator, list):
-                data_iterators = [data_iterator]
-            else:
-                data_iterators = data_iterator
-            for d in data_iterators:
-                assert (
-                    isinstance(d, RerunDataIterator),
-                    "data iterator is not wrapped with RerunDataIterator",
-                )
+        data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

        # Are we about to start the initial run?
        if self.state == RerunState.NOT_RUNNING_YET:
            if self.mode == RerunMode.DISABLED:
                self.state = RerunState.INITIAL_RUN
+                self.current_iteration += 1  # Increment self.current_iteration for reporting.
                return True
            if self.data_iterator_checkpoints is not None:
-                assert (
-                    len(self.data_iterator_checkpoints) == len(data_iterators),
-                    "data_iterator has different length than checkpointed data iterator",
-                )
+                assert len(self.data_iterator_checkpoints) == len(
+                    data_iterators
+                ), "data iterator has different length than checkpointed data iterator"
                for i, d in enumerate(data_iterators):
-                    d.set_checkpoint_state(self.data_iterator_checkpoints[i])
+                    d.load_state_dict(self.data_iterator_checkpoints[i])
                self.data_iterator_checkpoints = None
            self._save_state()
            if data_iterators:
@@ -464,10 +486,28 @@ class RerunStateMachine:
        verifying the result is the same.
        """
-        # Skip the validation check if the state machine is disabled or if we haven't run
-        # a full iteration yet. We cannot guarantee that a checkpoint can be taken before the
-        # optimizer has been stepped at least once.
-        if self.mode == RerunMode.DISABLED or self.current_iteration < 1:
+        # If reruns are disabled, still validate the result and throw a RuntimeError if it is
+        # rejected. This is a backward-compatible behavior.
+        if self.mode == RerunMode.DISABLED:
+            result_rejected: bool = rejection_func(result)
+            if result_rejected:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.RERUN_DISABLED, result=result, message=message
+                )
+                rank: int = _safe_get_rank()
+                node: str = os.uname()[1]
+                device: int = torch.cuda.current_device()
+                full_message: str = (
+                    f"Rank {rank}, node {node}, device {device}, "
+                    f"iteration {self.current_iteration}: "
+                    f"Unexpected result {result} (message='{message}')"
+                )
+                raise RuntimeError(full_message)
+            return
+
+        # Skip the validation on the first iteration, as we cannot guarantee a checkpoint can be
+        # taken before the optimizer has been stepped at least once.
+        if self.current_iteration < 1:
            return

        if comparison_func is None:

@@ -523,6 +563,9 @@ class RerunStateMachine:
                self.failed_validation_call = validation_call
                self.initial_result = result
                self.rerun_requested = True
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.INITIAL_RUN, result=result, message=message
+                )
                logger.error(
                    f"Unexpected result {result} at {validation_call.caller.filename} "
                    f"line {validation_call.caller.lineno}, "

@@ -547,6 +590,11 @@ class RerunStateMachine:
                    "First rerun: unexpected result is not reproducible within the tolerance "
                    f"({result} != {self.initial_result})"
                )
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.FIRST_RERUN_NOT_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                log_failure("Possible transient error!")
            else:
                self.checkpoint_requested = True

@@ -554,6 +602,11 @@ class RerunStateMachine:
                # rerunning on the same GPU when we resume from the checkpoint.
                self.suspicious_node = os.uname()[1]
                self.suspicious_device = torch.cuda.current_device()
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.FIRST_RERUN_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "First rerun: unexpected result is reproducible within the tolerance "
                    f"({result} = {self.initial_result}). "

@@ -571,12 +624,22 @@ class RerunStateMachine:
                )
                self.restart_again_requested = True
            elif comparison > tolerance:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.SECOND_RERUN_NOT_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "Second rerun: unexpected result is not reproducible on a different GPU, "
                    f"therefore was likely incorrect ({result} != {self.initial_result})"
                )
                log_failure("Possible persistent error!")
            else:
+                self._log_validation_error_to_file(
+                    status=RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE,
+                    result=result,
+                    message=message,
+                )
                logger.warning(
                    "Second rerun: unexpected result is reproducible on a different GPU, "
                    f"therefore it was likely correct ({result} = {self.initial_result})"

@@ -587,13 +650,31 @@ class RerunStateMachine:
        else:
            raise RuntimeError("Should not be here")
-    def is_spiky_loss(self, loss_tensor: torch.Tensor, threshold: float) -> bool:
-        """Helper method to estimate whether a loss is spiky.
+    def is_unexpectedly_large(
+        self,
+        result: torch.Tensor,
+        threshold: float,
+        context: str,
+        num_samples: int = 100,
+        resample: bool = False,
+    ) -> bool:
+        """Helper method to estimate whether a result is unexpectedly large.
+
+        Some calculation errors manifest themselves as results with unexpectedly large
+        exponents, e.g. spiky loss or grads. This method keeps track of a value over time
+        and flags it if it exceeds a certain threshold expressed as a multiple factor of
+        the max value observed.

        Args:
            loss_tensor: a zero-dim tensor containing the current loss.
-            threshold: a float representing the minimum relative variation
-                characterizing a spiky loss (e.g. 0.1 means +/- 10%).
+            threshold: a float representing the minimum trigger threshold
+                e.g. 10 means > 10x max absolute value observed.
+            context: a string identifying the value. This is used to differentiate
+                between different invokations of validate_results targetting different
+                values, e.g. loss and grads.
+            num_samples: the sample size used to estimate the max value.
+                Default is 100 value samples.
+            reset: whether to resample the max value. Default is False.

        Returns:
            A boolean telling whether the current loss deviates from the previous
            loss by a factor greater than the threshold

@@ -612,37 +693,42 @@ class RerunStateMachine:
                loss = loss_fn(outputs)
                rerun_machine.validate_result(
                    result=loss,
-                    rejection_func=partial(rerun_machine.is_spiky_loss, threshold=0.1),
+                    rejection_func=partial(
+                        rerun_machine.is_unexpectedly_large,
+                        threshold=10,
+                        context="loss",
+                    ),
                    message="Spiky loss",
                    tolerance=0.0,
                    fatal=False,
                )
        """
-        loss: float = loss_tensor.item()
-        result: bool = False
-        if self.last_loss is not None:
-            # Ignore NaNs, and consider infinite loss as spiky.
-            if math.isnan(loss) or math.isnan(self.last_loss):
-                result = False
-            elif math.isinf(loss) or math.isinf(self.last_loss):
-                result = True
-            else:
-                result = math.fabs(loss - self.last_loss) / self.last_loss >= threshold
-        self.last_loss = loss
-        return result
+        value: float = math.fabs(result.item())
+        # Ignore NaNs and Infs. They should be checked separately.
+        if math.isnan(value) or math.isinf(value):
+            return False
+        if resample or context not in self.large_value_counts:
+            self.large_value_counts[context] = 0
+        if self.large_value_counts[context] < num_samples:
+            self.large_value_counts[context] += 1
+            self.max_values[context] = max(self.max_values.get(context, 0.0), value)
+            if self.large_value_counts[context] == num_samples:
+                logger.warning(f"Max value for {context}: {self.max_values[context]}")
+            return False
+        return value >= self.max_values[context] * threshold
-    def get_checkpoint_state(
-        self, data_iterator: Optional[Union["RerunDataIterator", list]]
-    ) -> list[dict[str, Any]]:
+    def state_dict(self, data_iterator: DataIteratorArgType, use_dist_ckpt: bool) -> dict[str, Any]:
        """Method that returns a state dict to be checkpointed.

        Args:
            data_iterator: the data iterator that needs to be checkpointed (or None
                if this checkpoint is not requested by the rerun state machine).
+            use_dist_ckpt: generate a distributed checkpoint.

        Returns:
-            A list of state dicts, each state dict representing the rerun state machine
-            for one rank.
+            A state dict representing the rerun state machine.

        Example usage:

@@ -651,65 +737,63 @@ class RerunStateMachine:
                ...
                rerun_state_machine = get_rerun_state_machine()
                checkpoint['rerun_state_machine'] = (
-                    rerun_state_machine.get_checkpoint_state(data_iterator)
+                    rerun_state_machine.state_dict(data_iterator, False)
                )
                ...
                return checkpoint
        """
-        data_iterators: list[RerunDataIterator]
-        if self.mode == RerunMode.DISABLED:
-            data_iterators = []
-        elif isinstance(data_iterator, (list, tuple)):
-            data_iterators = data_iterator
-        else:
-            data_iterators = [data_iterator] if data_iterator is not None else []
-        for d in data_iterators:
-            assert (
-                isinstance(d, RerunDataIterator),
-                "data iterator is not wrapped with RerunDataIterator",
-            )
+        data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

-        state: dict[str, Any] = {
+        # The RerunStateMachine state is different across all ranks. Therefore it needs to be
+        # checkpointed using a ShardedObject. However, we keep the common state in the non-sharded
+        # (common) checkpoint. This allows us to verify whether a checkpoint contains a
+        # RerunStateMachine state by checking the common checkpoint.
+        state_dict: dict[str, Any] = {
            'mode': self.mode,
-            'state': self.state,
-            'current_iteration': self.current_iteration,
-            'rerun_requested': self.rerun_requested,
-            'checkpoint_requested': self.checkpoint_requested,
-            'restart_again_requested': self.restart_again_requested,
-            'continue_requested': self.continue_requested,
-            # logged_sdc_enabled should not be saved (set at the job startup time).
-            'error_injector_checkpoint': self.error_injector.get_checkpoint_state(),
-            # validation_counts should not be saved (reset at the beginning of the training loop).
-            'failed_validation_call': self.failed_validation_call,
-            'initial_result': self.initial_result,
-            'suspicious_node': self.suspicious_node,
-            'suspicious_device': self.suspicious_device,
-            # No need to save saved_state (RNG state already captured in checkpoint).
-            'data_iterator_checkpoints': (
-                [d.get_checkpoint_state() for d in data_iterators] if data_iterators else None
-            ),
-            'last_loss': self.last_loss,
-            # No need to save saved_results and stats (resets when job resumes).
+            'sharded': {
+                'state': self.state,
+                'current_iteration': self.current_iteration,
+                'rerun_requested': self.rerun_requested,
+                'checkpoint_requested': self.checkpoint_requested,
+                'restart_again_requested': self.restart_again_requested,
+                'continue_requested': self.continue_requested,
+                # logged_sdc_enabled should not be saved (set at the job startup time).
+                'error_injector_checkpoint': self.error_injector.state_dict(),
+                # validation_counts should not be saved (reset at start of training loop).
+                'failed_validation_call': self.failed_validation_call,
+                'initial_result': self.initial_result,
+                'suspicious_node': self.suspicious_node,
+                'suspicious_device': self.suspicious_device,
+                # No need to save saved_state (RNG state already captured in checkpoint).
+                'data_iterator_checkpoints': (
+                    [d.state_dict() for d in data_iterators] if data_iterators else None
+                ),
+                'large_value_counts': self.large_value_counts,
+                'max_values': self.max_values,
+                # No need to save saved_results and stats (resets when job resumes).
+            },
        }
-        state_list: list[dict[str, Any]]
-        if (
-            torch.distributed.is_initialized()
-            and torch.distributed.get_world_size() > 1
-            and self.mode != RerunMode.DISABLED
-        ):
-            state_list = [None for i in range(torch.distributed.get_world_size())]
-            torch.distributed.all_gather_object(state_list, state)
-        else:
-            state_list = [state]
-        return state_list
+        if use_dist_ckpt:
+            pp_rank = mpu.get_pipeline_model_parallel_rank()
+            pp_size = mpu.get_pipeline_model_parallel_world_size()
+            tp_rank = mpu.get_tensor_model_parallel_rank()
+            tp_size = mpu.get_tensor_model_parallel_world_size()
+            state_dict['sharded'] = ShardedObject(
+                'rerun_state_machine_state',
+                state_dict['sharded'],
+                (pp_size, tp_size),
+                (pp_rank, tp_rank),
+                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
+            )
+        return state_dict

-    def set_checkpoint_state(self, state_list: list[dict[str, Any]]) -> None:
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Method that restores the state from a checkpoint.

        Args:
-            state_list: the list of state dicts saved in the checkpoint and originally
-                obtained from get_checkpoint_state().
+            state_dict: the state dict saved in the checkpoint and originally
+                obtained from state_dict().

        Returns:
            None

@@ -719,31 +803,59 @@ class RerunStateMachine:
                ...
                if 'rerun_state_machine' in checkpoint:
                    rerun_state_machine = get_rerun_state_machine()
-                    rerun_state_machine.set_checkpoint_state(checkpoint['rerun_state_machine'])
+                    rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
        """
        if self.mode == RerunMode.DISABLED:
+            if _safe_get_rank() == 0:
+                logger.warning(
+                    "RerunStateMachine disabled via CLI, ignoring machine state saved in checkpoint"
+                )
            return
-        rank: int = _safe_get_rank()
-        if rank == 0:
+        if state_dict['mode'] == RerunMode.DISABLED:
+            if _safe_get_rank() == 0:
+                logger.warning(
+                    "RerunStateMachine disabled in checkpoint but enabled via CLI, "
+                    "ignoring machine state saved in checkpoint"
+                )
+            return
+        if _safe_get_rank() == 0:
            logger.warning(
-                "Getting RerunStaeMachine state from checkpoint, args rerun options ignored"
+                "Getting RerunStateMachine state from checkpoint, CLI rerun args ignored"
            )
-        state = state_list[rank]
-        self.mode = state['mode']
-        self.state = state['state']
-        self.current_iteration = state['current_iteration']
-        self.rerun_requested = state['rerun_requested']
-        self.checkpoint_requested = state['checkpoint_requested']
-        self.restart_again_requested = state['restart_again_requested']
-        self.continue_requested = state['continue_requested']
-        self.error_injector.set_checkpoint_state(state['error_injector_checkpoint'])
-        self.failed_validation_call = state['failed_validation_call']
-        self.initial_result = state['initial_result']
-        self.suspicious_node = state['suspicious_node']
-        self.suspicious_device = state['suspicious_device']
-        self.data_iterator_checkpoints = state['data_iterator_checkpoints']
-        self.last_loss = state['last_loss']
+        self.mode = state_dict['mode']
+        sharded_dict = state_dict['sharded']
+        self.state = sharded_dict['state']
+        self.current_iteration = sharded_dict['current_iteration']
+        self.rerun_requested = sharded_dict['rerun_requested']
+        self.checkpoint_requested = sharded_dict['checkpoint_requested']
+        self.restart_again_requested = sharded_dict['restart_again_requested']
+        self.continue_requested = sharded_dict['continue_requested']
+        self.error_injector.load_state_dict(sharded_dict['error_injector_checkpoint'])
+        self.failed_validation_call = sharded_dict['failed_validation_call']
+        self.initial_result = sharded_dict['initial_result']
+        self.suspicious_node = sharded_dict['suspicious_node']
+        self.suspicious_device = sharded_dict['suspicious_device']
+        self.data_iterator_checkpoints = sharded_dict['data_iterator_checkpoints']
+        self.large_value_counts = sharded_dict['large_value_counts']
+        self.max_values = sharded_dict['max_values']
+    def _sanitize_data_iterators(
+        self, data_iterator: DataIteratorArgType
+    ) -> list["RerunDataIterator"]:
+        data_iterators: list[RerunDataIterator]
+        if self.mode == RerunMode.DISABLED:
+            data_iterators = []
+        elif not isinstance(data_iterator, list):
+            data_iterators = [data_iterator]
+        else:
+            data_iterators = data_iterator
+        data_iterators = [d for d in data_iterators if d is not None]
+        for d in data_iterators:
+            assert isinstance(
+                d, RerunDataIterator
+            ), "data iterator is not wrapped with RerunDataIterator"
+        return data_iterators

    def _get_validation_call_info(self) -> Call:
        """Internal method to get the context about the caller to validate_result()."""

@@ -817,6 +929,64 @@ class RerunStateMachine:
            logger.info(f" From {caller.filename}, line {caller.lineno}:")
            logger.info(f" {stats.print_stats()}")

+    def _log_validation_error_to_file(
+        self, status: RerunValidationStatus, result: Any, message: str
+    ) -> None:
+        if self.result_rejected_tracker_filename is not None:
+            # Append to log.
+            try:
+                rank: int = _safe_get_rank()
+                node: str = os.uname()[1]
+                device: int = torch.cuda.current_device()
+                with open(self.result_rejected_tracker_filename, 'a') as f:
+                    print(
+                        f"ts={datetime.datetime.now()} node={node} device={device} "
+                        f"jobID={os.getenv('SLURM_JOBID', 'N/A')} rank={rank} "
+                        f"iteration={self.current_iteration} status={status} result={result} "
+                        f"message='{message}'",
+                        file=f,
+                    )
+            except Exception as e:
+                logger.error(f"Could not log validation error! ({e})")

+    @classmethod
+    def get_skipped_iterations_from_tracker_file(cls, tracker_file_name: str) -> list[int]:
+        """Get list of iterations to skip from results recorded in tracker file. If an
+        "abnormality" (e.g., NaN or infinity in gradient) is seen more than once on a
+        given rank and iteration, the corresponding iteration is skipped.
+
+        Args:
+            tracker_file_name (str): Name of tracker file.
+
+        Returns:
+            list[int]: List of iterations to skip.
+        """
+        iterations_to_skip: set[int] = set()
+        seen: set[Tuple[int, int]]
+        regex = r"ts=.+ node=.+ device=.+ jobID=.+ rank=(.+) iteration=(.+) status=(.+) .+"
+        try:
+            with open(tracker_file_name, 'r') as f:
+                for line in f.readlines():
+                    match = re.search(regex, line)
+                    if match:
+                        rank = int(match[1])
+                        iteration = int(match[2])
+                        status = match[3]
+                        # Skip an iteration if:
+                        #   - Reruns were disabled and it has failed on the same rank twice.
+                        #   or
+                        #   - Reruns were enabled and it was reproducible on the 2nd rerun
+                        if status == RerunValidationStatus.RERUN_DISABLED:
+                            if (rank, iteration) in seen:
+                                iterations_to_skip.add(iteration)
+                            else:
+                                seen.add((rank, iteration))
+                        elif status == RerunValidationStatus.SECOND_RERUN_REPRODUCIBLE:
+                            iterations_to_skip.add(iteration)
+        except Exception as e:
+            logger.error(f"Could not parse iterations to skip in tracker file! ({e})")
+        return sorted(iterations_to_skip)
class RerunDataIterator:
    """A wrapper class for data iterators that adds replay capability.

@@ -837,8 +1007,8 @@ class RerunDataIterator:
        replay_data_iterator = RerunDataIterator(data_iterator)
    """

-    def __init__(self, iterable: Any, make_iterable: bool = True) -> None:
-        self.iterable: Iterable[Any] = iter(iterable) if make_iterable else iterable
+    def __init__(self, iterable: Iterable[Any]) -> None:
+        self.iterable: Iterable[Any] = iterable
        self.saved_microbatches: list[Any] = []
        self.replaying: bool = False
        self.replay_pos: int = 0

@@ -870,7 +1040,7 @@ class RerunDataIterator:
        self.replaying = False
        self.saved_microbatches = []

-    def get_checkpoint_state(self) -> SerializableStateType:
+    def state_dict(self) -> SerializableStateType:
        """Method to capture the state of the iterator as a serializable dict."""
        return {

@@ -879,7 +1049,7 @@
            'replay_pos': self.replay_pos,
        }

-    def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
+    def load_state_dict(self, state_dict: SerializableStateType) -> None:
        """Method to restore the state saved as a serializable dict."""
        self.saved_microbatches = state_dict['saved_microbatches']

@@ -1051,7 +1221,7 @@ class RerunErrorInjector:
        else:
            raise RuntimeError("Should not be here")

-    def get_checkpoint_state(self) -> SerializableStateType:
+    def state_dict(self) -> SerializableStateType:
        """Method to capture the state of the error injector as a serializable dict."""
        return {

@@ -1061,7 +1231,7 @@
            'injected_error_type': self.injected_error_type,
        }

-    def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
+    def load_state_dict(self, state_dict: SerializableStateType) -> None:
        """Method to restore the state saved as a serializable dict."""
        self.error_injection_rate = state_dict['error_injection_rate']

@@ -1107,7 +1277,14 @@ def _set_rerun_state_machine(rerun_state_machine) -> None:
def _safe_get_rank() -> int:
    """Internal function that safely checks and returns the rank of the caller."""
-    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    # If torch.distributed is not initialized, try to read environment variables.
+    try:
+        return int(os.environ.get("RANK", 0))
+    except (ValueError, TypeError):
+        return 0

def _compare_floats(a: torch.Tensor, b: torch.Tensor) -> float:
...
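
Outside the diff itself, a short hedged sketch of how the new `result rejected` tracker file might be consumed on a restart; the file name is made up for illustration, while the class and classmethod are the ones added above:

# Hypothetical consumer of the tracker file written by _log_validation_error_to_file();
# the path is an assumption (e.g. supplied via a CLI flag in the training script).
from megatron.core.rerun_state_machine import RerunStateMachine

tracker_path = "result_rejected_tracker.log"  # assumed file name
skipped = RerunStateMachine.get_skipped_iterations_from_tracker_file(tracker_path)
print(f"Iterations to skip on restart: {skipped}")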
@@ -5,12 +5,14 @@
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass
-from typing import Union
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Union

import torch
from torch import Tensor

+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module

@@ -37,6 +39,9 @@ class MambaLayerSubmodules:
    mixer: Union[ModuleSpec, type] = IdentityOp
    mamba_bda: Union[ModuleSpec, type] = IdentityOp

+    # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
+    sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)

class MambaLayer(MegatronModule):
    """

@@ -57,6 +62,7 @@ class MambaLayer(MegatronModule):
        """Initialize Mamba Layer."""
        super().__init__(config)
        self.config = config
+        self.submodules_config = submodules
        self.layer_number = layer_number
        self.residual_in_fp32 = residual_in_fp32
        self.hidden_dropout = config.hidden_dropout

@@ -114,3 +120,26 @@ class MambaLayer(MegatronModule):
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Allocate the inference cache."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)

+    def sharded_state_dict(
+        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
+    ) -> ShardedStateDict:
+        """
+        Generate a sharded state dictionary for the mamba layer.
+
+        Args:
+            prefix (str, optional): Prefix to be added to all keys in the state dict.
+            sharded_offsets (tuple, optional): Tuple of sharding offsets.
+            metadata (Optional[dict], optional): Additional metadata for sharding.
+
+        Returns:
+            ShardedStateDict: A dictionary containing the sharded state of the mamba layer.
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+        prefixed_map = {
+            f'{prefix}{k}': f'{prefix}{v}'
+            for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
+        }
+        if prefixed_map:
+            apply_prefix_mapping(sharded_state_dict, prefixed_map)
+        return sharded_state_dict
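
As an aside (the key below is invented for illustration, and this is not the real `apply_prefix_mapping` implementation), the effect of `sharded_state_dict_keys_map` with `remap_te_layernorm=True` is a plain prefix rewrite of sharded keys, roughly equivalent to:

# Minimal sketch of the prefix-remapping semantics assumed by the spec above.
keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}

def remap(key: str) -> str:
    for old_prefix, new_prefix in keys_map.items():
        if key.startswith(old_prefix):
            return new_prefix + key[len(old_prefix):]
    return key

print(remap('norm.weight'))  # -> mixer.in_proj.layer_norm_weight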
@@ -985,6 +985,14 @@ class ColumnParallelLinear(torch.nn.Module):
        """Keep compatibility with TE state dict."""
        return None

+    def __repr__(self):
+        tp = self.output_size // self.output_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

@@ -1206,3 +1214,11 @@ class RowParallelLinear(torch.nn.Module):
    def get_extra_state(self) -> None:
        """Keep compatibility with TE state dict."""
        return None

+    def __repr__(self):
+        tp = self.input_size // self.input_size_per_partition
+        use_bias = self.bias is not None and self.bias is True
+        return (
+            f"{type(self).__name__}(in_features={self.input_size}, "
+            f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
+        )
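
For illustration only (the sizes are made up), the intent of the new `__repr__` methods is output along these lines:

# Hypothetical values: a 4096 -> 16384 column-parallel layer split across 4 TP ranks.
input_size, output_size, output_size_per_partition, has_bias = 4096, 16384, 4096, True
tp = output_size // output_size_per_partition
print(f"ColumnParallelLinear(in_features={input_size}, out_features={output_size}, "
      f"bias={has_bias}, TP={tp})")
# -> ColumnParallelLinear(in_features=4096, out_features=16384, bias=True, TP=4)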
@@ -5,10 +5,12 @@
import contextlib
import logging
+from functools import partial
+from typing import Union

import torch
from torch import _C
-from torch.cuda import _lazy_call
+from torch.cuda import _lazy_call, _lazy_init
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

@@ -21,17 +23,59 @@ from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_da
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

+try:
+    import transformer_engine  # pylint: disable=unused-import
+
+    HAVE_TE = True
+except ModuleNotFoundError:
+    HAVE_TE = False

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'

-def _set_cuda_rng_state(new_state, device=-1):
+def _get_cuda_rng_state(
+    device: Union[int, str, torch.device] = "cuda", clone: bool = False, graph_safe: bool = False
+) -> torch.Tensor:
+    """Return the random number generator state of the specified GPU.
+
+    Arguments:
+        device (int): The gpu to retrieve the rng state
+        clone (bool): Whether to also clone the retrieved RNG state
+        graph_safe (bool): Get the rng state in a graph safe manner.
+
+    This function is adapted from torch.cuda.random.get_rng_state()"""
+
+    # if not using cuda graphs, just use the builtin pytorch function
+    if not graph_safe:
+        return torch.cuda.random.get_rng_state(device=device)
+
+    _lazy_init()
+    if isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+    idx = device.index
+    if idx is None:
+        idx = torch.cuda.current_device()
+    default_generator = torch.cuda.default_generators[idx]
+    if clone:
+        return default_generator.clone_state()
+    return default_generator.graphsafe_get_state()
+
+
+def _set_cuda_rng_state(new_state: torch.Tensor, device: int = -1, graph_safe: bool = False):
    """Sets the random number generator state of the current GPU.

-    Argumentss:
+    Arguments:
        new_state (torch.ByteTensor): The desired state
+        device (int): The gpu to retrieve the rng state
+        graph_safe (bool): Set the rng state in a graph safe manner.

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.

@@ -56,7 +100,12 @@ def _set_cuda_rng_state(new_state, device=-1):
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
-            default_generator.set_state(new_state)
+
+            # if graph capturing, set the rng state in a cudagraphable way
+            if graph_safe:
+                default_generator.graphsafe_set_state(new_state)
+            else:
+                default_generator.set_state(new_state)

    _lazy_call(cb)
@@ -82,8 +131,17 @@ class CudaRNGStatesTracker:
    cuda state.
    """

-    def __init__(self):
+    def __init__(self, use_cudagraphable_rng=False):
        self.reset()
+        self.use_cudagraphable_rng = use_cudagraphable_rng
+
+        if self.use_cudagraphable_rng:
+            assert (
+                hasattr(torch.cuda.CUDAGraph, "register_generator_state")
+                and hasattr(torch.Generator, "graphsafe_set_state")
+                and hasattr(torch.Generator, "graphsafe_get_state")
+                and hasattr(torch.Generator, "clone_state")
+            ), "Tried using cudagraphs with RNG, however not detected in pytorch!"

    def is_initialized(self):
        """Checks if the internal RNG state has been set wirth set_states()."""

@@ -125,13 +183,20 @@ class CudaRNGStatesTracker:
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
-        # Get the current rng state.
-        orig_rng_state = torch.cuda.get_rng_state()
-        # Set the new state and store it.
-        torch.cuda.manual_seed(seed)
-        self.states_[name] = torch.cuda.get_rng_state()
-        # Reset rng state to what it was.
-        _set_cuda_rng_state(orig_rng_state)
+
+        # If available, create the state in a graph safe manner
+        if self.use_cudagraphable_rng:
+            new_state = _get_cuda_rng_state(clone=True, graph_safe=True)
+            new_state.manual_seed(seed)
+            self.states_[name] = new_state
+        else:
+            # Get the current rng state.
+            orig_rng_state = torch.cuda.get_rng_state()
+            # Set the new state and store it.
+            torch.cuda.manual_seed(seed)
+            self.states_[name] = torch.cuda.get_rng_state()
+            # Reset rng state to what it was.
+            _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):

@@ -141,9 +206,9 @@ class CudaRNGStatesTracker:
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
-        orig_cuda_rng_state = torch.cuda.get_rng_state()
+        orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
        # Set rng state to the desired one
-        _set_cuda_rng_state(self.states_[name])
+        _set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
        # Record cpu RNG state
        cpu_rng_state = torch.get_rng_state()
        # Do the stuff we wanted to do.

@@ -154,9 +219,9 @@ class CudaRNGStatesTracker:
        if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
            logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
        # Update the current rng state for later use.
-        self.states_[name] = torch.cuda.get_rng_state()
+        self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
        # And set the state to the original state we started with.
-        _set_cuda_rng_state(orig_cuda_rng_state)
+        _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)

# RNG tracker object.

@@ -164,35 +229,85 @@ _CUDA_RNG_STATE_TRACKER = None
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False

-def initialize_rng_tracker(use_te_rng_tracker: bool = False):
+def initialize_rng_tracker(
+    use_te_rng_tracker: bool = False,
+    inference_rng_tracker: bool = False,
+    use_cudagraphable_rng: bool = False,
+):
    """Create the RNG tracker. 'use_te_rng_tracker' determines whether to use
    Megatron or TransformerEngine's implementation.
    In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
    """
    global _CUDA_RNG_STATE_TRACKER
    global _CUDA_RNG_STATE_TRACKER_INITIALIZED
    if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
        return
-    if use_te_rng_tracker:
+
+    # Get the base tracker class
+    base_tracker = None
+    if HAVE_TE and use_te_rng_tracker:
        if not is_te_min_version("1.5.0"):
            raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

-        _CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker()
+        base_tracker = TECudaRNGStatesTracker
    else:
-        _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+        base_tracker = partial(CudaRNGStatesTracker, use_cudagraphable_rng=use_cudagraphable_rng)
+
+    if inference_rng_tracker:
+
+        class InferenceCudaRNGStatesTracker(base_tracker):
+            """RNG tracker for inference."""
+
+            def add(self, name, seed):
+                """Mirrors the interface from the training RNG tracker."""
+                pass
+
+            def set_states(self, states):
+                """Mirrors the interface from the training RNG tracker."""
+                pass
+
+            def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
+                """Mirrors the interface from the training RNG tracker."""
+                return contextlib.nullcontext()
+
+        tracker_class = InferenceCudaRNGStatesTracker
+    else:
+        tracker_class = base_tracker
+
+    _CUDA_RNG_STATE_TRACKER = tracker_class()
    _CUDA_RNG_STATE_TRACKER_INITIALIZED = True
-def get_cuda_rng_tracker(use_te_rng_tracker=False):
+def get_cuda_rng_tracker(use_te_rng_tracker=False, inference_rng_tracker=False):
    """Get cuda rng tracker."""
-    initialize_rng_tracker(use_te_rng_tracker)
+    initialize_rng_tracker(use_te_rng_tracker, inference_rng_tracker)
    return _CUDA_RNG_STATE_TRACKER

-def model_parallel_cuda_manual_seed(seed):
+def get_all_rng_states() -> bool:
+    """Returns all generator states used by the current `CudaRNGStatesTracker`."""
+    assert (
+        _CUDA_RNG_STATE_TRACKER_INITIALIZED
+    ), "Tried getting all rng states but RNG Tracker has not been initalized!"
+
+    if isinstance(_CUDA_RNG_STATE_TRACKER, CudaRNGStatesTracker):
+        return _CUDA_RNG_STATE_TRACKER.states_
+    # If TE is installed, check if we are using TE's RNG tracker
+    elif HAVE_TE and is_te_min_version("1.5.0"):
+        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker
+
+        if isinstance(_CUDA_RNG_STATE_TRACKER, TECudaRNGStatesTracker):
+            from transformer_engine.pytorch.distributed import get_all_rng_states
+
+            return get_all_rng_states()
+    # no valid tracker, return an empty dict
+    else:
+        return {}
+
+
+def model_parallel_cuda_manual_seed(seed, te_rng_tracker=False, inference_rng_tracker=False):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is

@@ -216,7 +331,7 @@ def model_parallel_cuda_manual_seed(seed):
    # Data parallel gets the original seed.
    data_parallel_seed = seed

-    initialize_rng_tracker()
+    initialize_rng_tracker(te_rng_tracker, inference_rng_tracker)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)

@@ -239,14 +354,16 @@ class CheckpointFunction(torch.autograd.Function):
    2) the states in the model parallel tracker are also properly tracked/set/reset.
    """

+    # pylint: disable=missing-function-docstring
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
+        """Forward pass."""
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():

@@ -265,8 +382,10 @@ class CheckpointFunction(torch.autograd.Function):
        return outputs

+    # pylint: disable=missing-function-docstring
    @staticmethod
    def backward(ctx, *args):
+        """Backward pass."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "

@@ -280,7 +399,7 @@ class CheckpointFunction(torch.autograd.Function):
        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = torch.cuda.get_rng_state()
+        bwd_cuda_rng_state = _get_cuda_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.

@@ -302,7 +421,9 @@ class CheckpointFunction(torch.autograd.Function):
            outputs = (outputs,)

        # filter out non tensor outputs for backward pass
-        outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
+        outputs, args = zip(
+            *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args))
+        )
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
        return (None, None) + grads
...
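
A hedged usage sketch of the RNG tracker path touched above (assumptions: a CUDA device is available and Megatron Core model parallelism has already been initialized; the seed and tensor shape are arbitrary):

# Sketch: seeding and forking the model-parallel RNG state.
import torch
from megatron.core.tensor_parallel.random import (
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)

model_parallel_cuda_manual_seed(1234)  # sets default, model-parallel, and expert-parallel states
with get_cuda_rng_tracker().fork():
    # Draws here consume the 'model-parallel-rng' state; the outer CUDA RNG
    # state is restored when the context exits.
    noise = torch.rand(4, device="cuda")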
...@@ -24,12 +24,20 @@ class TimerBase(ABC): ...@@ -24,12 +24,20 @@ class TimerBase(ABC):
@abstractmethod @abstractmethod
def start(self, barrier=False): def start(self, barrier=False):
"""Start the timer.""" """Start the timer.
Args:
barrier (bool, optional): Synchronizes ranks before starting. Defaults to False.
"""
pass pass
@abstractmethod @abstractmethod
def stop(self, barrier=False): def stop(self, barrier=False):
"""Stop the timer.""" """Stop the timer.
Args:
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
"""
pass pass
@abstractmethod @abstractmethod
...@@ -39,7 +47,15 @@ class TimerBase(ABC): ...@@ -39,7 +47,15 @@ class TimerBase(ABC):
@abstractmethod @abstractmethod
def elapsed(self, reset=True, barrier=False): def elapsed(self, reset=True, barrier=False):
"""Calculates the elapsed time.""" """Calculates the elapsed time and restarts timer.
Args:
reset (bool, optional): Resets timer before restarting. Defaults to True.
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
Returns:
float: Elapsed time.
"""
pass pass
...@@ -59,7 +75,19 @@ class DummyTimer(TimerBase): ...@@ -59,7 +75,19 @@ class DummyTimer(TimerBase):
return return
def elapsed(self, reset=True, barrier=False): def elapsed(self, reset=True, barrier=False):
raise Exception('dummy timer should not be used to calculate elapsed time') raise Exception(
'dummy timer should not be used to calculate elapsed time, '
'check if timer\'s log_level <= self._log_level.'
)
def active_time(self):
"""Returns the cumulative duration the timer has been active.
Note: Not supported for DummyTimer.
"""
raise Exception(
'active timer should not be used to calculate elapsed time, '
'check if timer\'s log_level <= self._log_level.'
)
class Timer(TimerBase): class Timer(TimerBase):
...@@ -155,7 +183,7 @@ class Timer(TimerBase): ...@@ -155,7 +183,7 @@ class Timer(TimerBase):
return _elapsed return _elapsed
def active_time(self): def active_time(self):
"""Returns the active time.""" """Calculates the cumulative duration for which the timer has been active"""
return self._active_time return self._active_time
@@ -397,8 +425,8 @@ class Timers:
        reset: bool = True,
        barrier: bool = False,
    ):
-       """Write timers to a tensorboard writer. Note that we only report maximum time across ranks
-       to tensorboard.
+       """Write timers to a tensorboard writer.
+       Note that we only report maximum time across ranks to tensorboard.

        Args:
            names (List[str]): Names of the timers to log.
...
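The expanded docstrings spell out the start/stop/elapsed contract, including the optional rank barrier. As a rough illustration of those semantics only (this is not the actual Megatron Timer implementation, and SimpleTimer is a made-up name), a timer honoring that contract might look like this:

import time
import torch

class SimpleTimer:
    """Toy timer mirroring the documented semantics: start/stop with an optional
    barrier, elapsed() that returns seconds and restarts by default."""

    def __init__(self):
        self._start = None
        self._elapsed = 0.0
        self._active_time = 0.0

    def _maybe_barrier(self, barrier):
        if barrier and torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()

    def start(self, barrier=False):
        self._maybe_barrier(barrier)
        self._start = time.time()

    def stop(self, barrier=False):
        self._maybe_barrier(barrier)
        delta = time.time() - self._start
        self._elapsed += delta
        self._active_time += delta
        self._start = None

    def elapsed(self, reset=True, barrier=False):
        was_running = self._start is not None
        if was_running:
            self.stop(barrier=barrier)
        value = self._elapsed
        if reset:
            self._elapsed = 0.0
        if was_running:
            self.start(barrier=barrier)
        return value

    def active_time(self):
        # Cumulative active duration; not cleared by elapsed(reset=True).
        return self._active_time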
File mode changed from 100755 to 100644
@@ -103,6 +103,10 @@ class Attention(MegatronModule, ABC):
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

+       # To support both CUDA Graphs and key value with different hidden size
+       self.key_hidden_size = self.hidden_size_per_attention_head
+       self.val_hidden_size = self.hidden_size_per_attention_head
+
        self.core_attention = build_module(
            submodules.core_attention,
            config=self.config,
@@ -110,6 +114,7 @@ class Attention(MegatronModule, ABC):
            attn_mask_type=self.attn_mask_type,
            attention_type=self.attention_type,
            cp_comm_type=cp_comm_type,
+           softmax_scale=self.config.softmax_scale,
        )

        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
@@ -189,6 +194,7 @@ class Attention(MegatronModule, ABC):
        rotary_pos_emb: Tensor,
        rotary_pos_cos: Tensor = None,
        rotary_pos_sin: Tensor = None,
+       sequence_len_offset=None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
@@ -209,10 +215,10 @@ class Attention(MegatronModule, ABC):
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
-               inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
+               inf_max_seq_length, inf_max_batch_size, self.key_hidden_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
-               inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
+               inf_max_seq_length, inf_max_batch_size, self.val_hidden_size, value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
@@ -234,7 +240,10 @@ class Attention(MegatronModule, ABC):
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
-       assert sequence_end <= inference_key_memory.size(0)
+       assert sequence_end <= inference_key_memory.size(0), (
+           "Current sequence length is longer than expected maximum sequence length! "
+           "Increase inference_max_seq_length."
+       )

        if self.config.flash_decode:
            assert (
@@ -245,7 +254,7 @@ class Attention(MegatronModule, ABC):
                rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
                rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
                rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
-           else:
+           else:  # Prefill
                rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
                rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
                rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
@@ -340,6 +349,7 @@ class Attention(MegatronModule, ABC):
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
+       sequence_len_offset=None,
    ):
        """
        Perform a forward pass through the attention module.
@@ -371,15 +381,15 @@ class Attention(MegatronModule, ABC):
        if (
            self.config.flash_decode
            and inference_params is not None
-           and self.layer_number
-           in inference_params.key_value_memory_dict
+           and inference_params.decode_mode  # Decode phase if key already exists
        ):
+           assert self.layer_number in inference_params.key_value_memory_dict
            assert inference_params.sequence_len_offset is not None
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]
            output = self.flash_decoding(
-               sequence_len_offset=inference_params.sequence_len_offset,
+               sequence_len_offset=sequence_len_offset,
                query_layer=query,
                key_layer=key,
                value_layer=value,
@@ -394,7 +404,14 @@ class Attention(MegatronModule, ABC):
            return output, bias

        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
-           inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
+           inference_params,
+           query,
+           key,
+           value,
+           rotary_pos_emb,
+           rotary_pos_cos,
+           rotary_pos_sin,
+           sequence_len_offset,
        )

        if packed_seq_params is not None:
...
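The new sequence_len_offset plumbing feeds the KV-cache update. The underlying bookkeeping, writing a new key/value chunk into a preallocated cache at a sequence offset, can be sketched outside of Megatron as follows; the shapes and the update_kv_cache helper are illustrative assumptions, not the real API.

import torch

def update_kv_cache(key_cache, value_cache, key, value, batch_start, sequence_start):
    """Write the new key/value chunk into preallocated [seq, batch, ...] caches
    at the given offsets, mirroring the prefill/decode bookkeeping above."""
    sequence_end = sequence_start + key.size(0)
    batch_end = batch_start + key.size(1)
    assert sequence_end <= key_cache.size(0), (
        "Current sequence length is longer than the preallocated maximum; "
        "allocate a larger cache."
    )
    key_cache[sequence_start:sequence_end, batch_start:batch_end, ...] = key
    value_cache[sequence_start:sequence_end, batch_start:batch_end, ...] = value
    return (
        key_cache[:sequence_end, batch_start:batch_end, ...],
        value_cache[:sequence_end, batch_start:batch_end, ...],
    )

# usage: prefill 16 tokens, then decode one token at offset 16
max_seq, batch, heads, dim = 32, 2, 4, 8
k_cache = torch.zeros(max_seq, batch, heads, dim)
v_cache = torch.zeros(max_seq, batch, heads, dim)
k, v = torch.randn(16, batch, heads, dim), torch.randn(16, batch, heads, dim)
update_kv_cache(k_cache, v_cache, k, v, batch_start=0, sequence_start=0)
k1, v1 = torch.randn(1, batch, heads, dim), torch.randn(1, batch, heads, dim)
k_all, v_all = update_kv_cache(k_cache, v_cache, k1, v1, batch_start=0, sequence_start=16)
print(k_all.shape)  # torch.Size([17, 2, 4, 8])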
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

+import gc
+import inspect
import logging
-import time
+from collections import defaultdict
+from contextlib import nullcontext
+from dataclasses import fields, is_dataclass
from enum import Enum

import torch
+from torch.utils._pytree import tree_flatten

+from megatron.core import parallel_state
+from megatron.core.tensor_parallel.random import (
+    CudaRNGStatesTracker,
+    get_all_rng_states,
+    get_cuda_rng_tracker,
+)
+from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import is_te_min_version

try:
-    from transformer_engine.pytorch import make_graphed_callables
-    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
+    from transformer_engine.pytorch.graph import restore_fp8_tensors, save_fp8_tensors
+    from transformer_engine.pytorch.graph import set_capture_end as te_set_capture_end
+    from transformer_engine.pytorch.graph import set_capture_start as te_set_capture_start
+    from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
+
+    from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

    HAVE_TE_GRAPHS = True
except:
    HAVE_TE_GRAPHS = False
_IS_GRAPH_CAPTURING = False
logger = logging.getLogger(__name__)
def is_graph_capturing():
"""Query if currently capturing."""
global _IS_GRAPH_CAPTURING
return _IS_GRAPH_CAPTURING
def _set_capture_start():
"""Set graph capture has started."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = True
def _set_capture_end():
"""Set graph capture has ended."""
global _IS_GRAPH_CAPTURING
_IS_GRAPH_CAPTURING = False
def _check_supported_type(arg):
"""Check if arg is a supported type for cudagraph input/outputs."""
from megatron.core import InferenceParams # guard against circular import
_SUPPORTED_TYPES = {torch.Tensor, type(None), bool, int, str, float, InferenceParams}
assert type(arg) in _SUPPORTED_TYPES or is_dataclass(
arg
), f"Cudagraphs received an arg of type {type(arg)} which is not supported."
class _CudagraphGlobalRecord:
"""A global datastructure that records the ordering of all _CudaGraphRunner's
first fwd or bwd passes. 'create_cudagraphs' will use this to create
cudagraphs in execution order, which is required for cudagraphs sharing a mempool."""
"""A global flag that, if true, means all cudagraph runners'
fwd and bwd passes will be performed using their cudagraphed versions."""
cudagraph_created = False
"""A record of fwd and bwd graph creation, populated with 'record_fwd_graph' and
'record_bwd_graph'."""
cudagraph_record = []
@classmethod
def record_fwd_graph(cls, runner, args, kwargs):
"""Record a fwd graph to 'cudagraph_record"""
cls.cudagraph_record.append((runner, "fwd", args, kwargs))
@classmethod
def record_bwd_graph(cls, runner):
"""Record a bwd graph to 'cudagraph_record"""
cls.cudagraph_record.append((runner, "bwd"))
@classmethod
def create_cudagraphs(cls):
"""Iterate through 'cudagraph_record' creating graphs in the order in which
they were recorded."""
# Cudagraphs have already been created, check that no cudagraphed modules ran in eager mode
if cls.cudagraph_created:
assert len(cls.cudagraph_record) == 0, (
"One or more _CudaGraphRunners requested to create a graph after cudagraphs "
"were already created!"
)
return
# No cudagraphs have been created or recorded, so do nothing
if len(cls.cudagraph_record) == 0:
return
# Otherwise, create all the recorded cudagraphs.
logging.getLogger(__name__).info(f"Creating {len(cls.cudagraph_record)} CUDA graphs")
has_te_modules = False
if HAVE_TE_GRAPHS:
for g in cls.cudagraph_record:
base_module = g[0].base_module
has_te_modules = has_te_modules or any(
[isinstance(m, TransformerEngineBaseModule) for m in base_module.modules()]
)
# If graphing only transformer layers with self attention, then apply the following
# transformer layer specific optimizations that reduce memory usage and tensor copies:
# These eventually will become unnecessary with:
# https://github.com/pytorch/pytorch/pull/137318
# 1. Some inputs to TransformerLayer (e.g. rotary_emb) are the same over all layers
# and only need to be set once.
# 2. Because the next layer consumes the previous layer's hidden states, all fwd
# cudagraphs can alternate reusing the same hidden_state input, output buffer.
# Similarly, bwd graphs can alternate the same output, input grad buffers.
optimize_transformer_layer_graph_buffers = all(
[g[0].is_transformer_decoder_layer for g in cls.cudagraph_record]
)
if optimize_transformer_layer_graph_buffers:
prev_fwd_hidden_state_output = None
prev_bwd_hidden_state_inputgrad = None
gc.collect()
torch.cuda.empty_cache()
_set_capture_start()
if has_te_modules:
te_set_capture_start()
for idx, g in enumerate(cls.cudagraph_record):
runner, graph_type = g[0:2]
if optimize_transformer_layer_graph_buffers:
if graph_type == 'fwd':
args, kwargs = g[2:]
if not runner.is_first_layer:
kwargs['hidden_states'] = prev_fwd_hidden_state_output
runner.create_fwd_graph(args, kwargs, clone_inputs=False)
# The output of TransformerLayer is: (hidden_states, None)
prev_fwd_hidden_state_output, _ = runner.fwd_graph_outputs
else:
runner.create_bwd_graph(prev_bwd_hidden_state_inputgrad)
# The first input grad TransformerLayer is for 'hidden_states'
if not runner.is_last_layer:
prev_bwd_hidden_state_inputgrad = runner.static_grad_inputs[0]
else:
runner, graph_type = g[0:2]
if graph_type == 'fwd':
args, kwargs = g[2:]
runner.create_fwd_graph(args, kwargs)
else:
runner.create_bwd_graph()
for g in cls.cudagraph_record:
runner = g[0]
runner.cudagraph_created = True
cls.cudagraph_created = True
cls.cudagraph_record = []
_set_capture_end()
if has_te_modules:
te_set_capture_end()
def create_cudagraphs():
"""Should be called at the end of each schedule function,
(e.g. forward_backward_pipelining_with_interleaving) in
`megatron.core.pipeline_parallel.schedules.py`. During the first step, _CudaGraphRunners
populate _CudagraphGlobalRecord with the global order in which cudagraphs should be created.
At the end for the first step, this function calls each runner's `create_fwd_graph` and
`create_bwd_graph` in the order recorded in _CudagraphGlobalRecord, which allows cudagraphs
to be created in execution order, which allows multiple cudagraphs to share a single
memory pool, minimizing cudagraph memory usage."""
_CudagraphGlobalRecord.create_cudagraphs()
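Per the docstring above, create_cudagraphs() should fire once every runner has recorded its first fwd/bwd order, i.e. at the end of the first step's schedule function. A rough, self-contained sketch of that call placement follows; the Linear model and plain loop are stand-ins for a real Megatron schedule, so with this toy model the call is effectively a no-op and only illustrates where it belongs.

import torch

# Hypothetical stand-in: in Megatron the model would be built from cudagraphed
# TransformerLayers and this call would live at the end of the schedule functions
# in megatron/core/pipeline_parallel/schedules.py.
model = torch.nn.Linear(8, 8, device="cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for iteration in range(10):
    batch = torch.randn(4, 8, device="cuda")
    loss = model(batch).sum()
    loss.backward()
    if iteration == 0:
        # By now every _CudaGraphRunner would have recorded its fwd/bwd order,
        # so the graphs can be built in execution order and share a mempool.
        create_cudagraphs()
    optimizer.step()
    optimizer.zero_grad()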
class _GraphStatus(Enum):
"""An Enum to track if a cudagraph is ready to perform a forward or backward pass.""" """An Enum to track if a cudagraph is ready to perform a forward or backward pass."""
FWD_READY = 0 FWD_READY = 0 # Set immediately after a bwd pass
BWD_READY = 1 BWD_READY = 1 # Set immediately after a fwd pass
class GraphStatusFunc(torch.autograd.Function): class _CudagraphRecordNode(torch.autograd.Function):
"""Inserts a node into the autograd graph that tracks whether an object has an outstanding """Inserts a noop node into the autograd graph, used to record when a bwd graph needs
backward pass by toggling the value of GraphStatus. This is mainly used to detect when to create to be created."""
multiple graphs per transformer layer for pipeline parallelism.
We don't use backward module hooks as they change forward output tensors to views, see:
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
"""
@staticmethod @staticmethod
def forward(ctx, runner, obj): def forward(ctx, runner, inputs):
"""Occurs immediately before the graph's forward pass. """Forward pass, does nothing but registers an autograd node."""
Marks the graph's backward pass as ready."""
assert (
runner.status == _GraphStatus.FWD_READY
), "Tried calling the fwd cudagraph when the bwd cudagraph was expected to be called next!"
ctx.runner = runner ctx.runner = runner
runner.status = GraphStatus.BWD_READY return inputs
return obj
@staticmethod @staticmethod
def backward(ctx, grad): def backward(ctx, grads):
"""Occurs immediately after the graph's backward pass. """If this is the first bwd pass of this runner, record that a
Marks the graph's forward pass as ready.""" bwd graph needs to be created."""
assert ctx.runner.status == GraphStatus.BWD_READY
ctx.runner.status = GraphStatus.FWD_READY
return None, grad
class TensorDescription:
"""Records the attributes of a tensor. Used to check if a
tensor argument matches the tensor with which the module
was graph captured with."""
def __init__(self, tensor):
self.shape = tuple(tensor.shape)
self.dtype = tensor.dtype
self.device = tensor.device
def matches_tensor(self, tensor):
"""Check if 'tensor' matches the attributes of this TensorDescription."""
assert torch.is_tensor(tensor)
return (
tensor.shape == self.shape
and tensor.dtype == self.dtype
and tensor.device == self.device
)
runner = ctx.runner
assert (
runner.status == _GraphStatus.BWD_READY
), "Tried calling the bwd cudagraph when the fwd cudagraph was expected to be called next!"
class CudaGraphCallable(torch.nn.Module): runner.status = _GraphStatus.FWD_READY
"""Wraps a module to be cudagraphable, records the output of the cudagraph.
Reinserts non-tensor args, kwargs that were previously filtered out by 'get_tensor_args'.
"""
def __init__(self, module, groundtruth_args, groundtruth_kwargs): if not runner.bwd_graph_recorded:
super().__init__() _CudagraphGlobalRecord.record_bwd_graph(runner)
self.add_module('base_module', module) runner.bwd_graph_recorded = True
# The Pytorch cudagraph API requires only tensor inputs, so we strip return None, grads
# non-tensor arguments and reinsert them in forward() using these groundtruth attributes.
# We will also check future calls to the cudagraph against these to ensure the cudagraph
# is called with the same inputs as it was captured with. class _CudagraphReplayNode(torch.autograd.Function):
self.groundtruth_outputs = [] """Replays the runner's cudagraphs with autograd. Handles copying data into/out of the
self.groundtruth_args = tuple( cudagraph io and fp8 if used."""
TensorDescription(a) if torch.is_tensor(a) else a for a in groundtruth_args
) @staticmethod
self.groundtruth_kwargs = { def forward(ctx, runner, is_first_microbatch, *inputs):
k: TensorDescription(v) if torch.is_tensor(v) else v """Replay the forward graph of the passed runner."""
for k, v in groundtruth_kwargs.items()
} assert (
runner.fwd_graph is not None
def forward(self, *arg_tensors, **kwarg_tensors): ), "Tried replaying fwd cudagraph before calling 'create_fwd_cudagraph!"
"""Call the forward pass of the cudagraph. Also checks the outputs
of the cudagraph matches what the graph was traced with."""
args = list(self.groundtruth_args)
arg_tensors = list(arg_tensors)
for idx, groundtruth_arg in enumerate(self.groundtruth_args):
if isinstance(groundtruth_arg, TensorDescription):
args[idx] = arg_tensors.pop(0)
kwargs = dict(self.groundtruth_kwargs)
for k, v in self.groundtruth_kwargs.items():
if isinstance(v, TensorDescription):
kwargs[k] = kwarg_tensors[k]
# Use forward() instead of __call__ to avoid triggering hooks
out = self.base_module.forward(*args, **kwargs)
if torch.is_tensor(out):
out = tuple(out)
self.groundtruth_outputs = [TensorDescription(o) if torch.is_tensor(o) else o for o in out]
out = tuple(o for o in out if torch.is_tensor(o))
assert ( assert (
len(out) > 0 runner.status == _GraphStatus.FWD_READY
), """A graphed module returned no tensors in training mode, however the graphed module ), "Tried calling the fwd cudagraph when the bwd cudagraph was expected to be called next!"
must output at least one tensor, so that a corresponding backward node assert len(inputs) == len(
may be registered in the autograd graph.""" runner.fwd_graph_input_surface
), "Fwd cudagraph received a different number of tensors than what it was graphed with!"
if len(out) == 1: # Copy new data into fwd graph input buffer
return out[0] for user_input, cudagraph_input in zip(inputs, runner.fwd_graph_input_surface):
if user_input.data_ptr() != cudagraph_input.data_ptr():
cudagraph_input.copy_(user_input)
ctx.runner = runner
if runner.fp8_enabled:
for m in runner.base_module.modules():
if isinstance(m, TransformerEngineBaseModule):
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
if is_te_min_version("1.13.0"):
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(m.fp8_meta)
else:
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, fp8_weights=m._get_fp8_params()
)
is_first_fp8_module = FP8GlobalStateManager.is_first_fp8_module()
if is_first_fp8_module:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch)
ctx.is_first_fp8_module = is_first_fp8_module
runner.fwd_graph.replay()
# if last transformer layer, return a clone of the cudagraph output buffer, as releasing
# the cudagraph output buffer into the rest of the system may allow it to be corrupted
if runner.is_last_layer:
out = tuple(o.clone().detach() for o in runner.fwd_graph_output_surface)
else:
out = tuple(o.detach() for o in runner.fwd_graph_output_surface)
return out return out
@staticmethod
def backward(ctx, *grads):
"""Replay the backward graph of the passed runner."""
class CudaGraphRunner(torch.nn.Module): runner = ctx.runner
"""Wraps a single cudagraph and its expected arguments. Checks that assert (
the provided args are the same as what the graph was traced with. runner.bwd_graph is not None
""" ), "Tried replaying bwd cudagraph before calling 'create_bwd_cudagraph'!"
assert (
runner.status == _GraphStatus.BWD_READY
), "Tried calling the bwd cudagraph when the fwd cudagraph was expected to be called next!"
assert len(grads) == len(
runner.static_grad_outputs
), "Bwd cudagraph received a different number of tensors than what it was graphed with!"
# Copy new data into bwd graph input buffer
for user_output_grad, cudagraph_output_grad in zip(grads, runner.static_grad_outputs):
if user_output_grad.data_ptr() != cudagraph_output_grad.data_ptr():
cudagraph_output_grad.copy_(user_output_grad)
runner.bwd_graph.replay()
runner.status = _GraphStatus.FWD_READY
# Update FP8 scale factors if needed
if runner.fp8_enabled and ctx.is_first_fp8_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
# If using gradient_accumulation_fusion, whenever `main_grad` is calculated
# the `grad_added_to_main_grad` attribute is expected to set. However when using
# cudagraphs this doesn't occur so we emulate this behavior here.
for param, grad_added in runner.groundtruth_grad_added_to_main_grad.items():
param.grad_added_to_main_grad = grad_added
if runner.is_first_layer:
output_grads = tuple(
b.clone().detach() if b is not None else b for b in runner.get_input_grads()
)
else:
output_grads = tuple(
b.detach() if b is not None else b for b in runner.get_input_grads()
)
return None, None, *output_grads
class _CudaGraphRunner(torch.nn.Module):
"""Represents the execution of a cudagraphed module for a single microbatch.
If there are multiple outstanding microbatches per module, such as for pipeline parallelism,
CudaGraphManager automatically creates multiple _CudaGraphRunners per module."""
def __init__(self, base_module, fwd_mempool, bwd_mempool):
"""Creates a _CudaGraphRunner, which holds a single pair of fwd and bwd cudagraphs, which
are not created until this runner records its graph creation into
'_CudagraphGlobalRecord', and 'create_cudagraphs()' is called."""
def __init__(self, graphed_module, wrapped_module):
super().__init__() super().__init__()
self.graphed_module = graphed_module self.base_module = base_module
self.groundtruth_args = wrapped_module.groundtruth_args self.fwd_mempool = fwd_mempool
self.groundtruth_kwargs = wrapped_module.groundtruth_kwargs self.bwd_mempool = bwd_mempool
self.groundtruth_outputs = wrapped_module.groundtruth_outputs
self.status = GraphStatus.FWD_READY self.fwd_graph = None
self.bwd_graph = None
self.fwd_graph_recorded = False
self.bwd_graph_recorded = False
self.cudagraph_created = False
self.status = _GraphStatus.FWD_READY
self.fuse_wgrad_accumulation = False
self.backward_retain_grad = False
self.fp8_enabled = False
self.deallocate_pipeline_outputs = False
self.num_warmup_steps = 2
if isinstance(self.base_module.config, TransformerConfig):
self.fuse_wgrad_accumulation = self.base_module.config.gradient_accumulation_fusion
self.backward_retain_grad = self.base_module.config.cuda_graph_retain_backward_graph
self.fp8_enabled = self.base_module.config.fp8 is not None
self.deallocate_pipeline_outputs = self.base_module.config.deallocate_pipeline_outputs
self.num_warmup_steps = self.base_module.config.cuda_graph_warmup_steps
if self.fp8_enabled:
self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
from megatron.core.transformer.transformer_layer import TransformerLayer
self.is_first_layer = None
self.is_last_layer = None
self.is_transformer_decoder_layer = False
if isinstance(base_module, TransformerLayer) and isinstance(
base_module.cross_attention, IdentityOp
):
self.is_transformer_decoder_layer = True
total_num_layers = base_module.config.num_layers
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
if vpp_size is None:
vpp_size = 1
layers_per_chunk = total_num_layers // vpp_size // pp_size
self.is_first_layer = ((base_module.layer_number - 1) % layers_per_chunk) == 0
self.is_last_layer = (base_module.layer_number % layers_per_chunk) == 0
def get_fp8_context(self):
"""Return a new fp8 context in cudagraph mode."""
if self.fp8_enabled:
return fp8_autocast(
enabled=True, calibrating=False, fp8_recipe=self.fp8_recipe, _graph=True
)
return nullcontext()
def create_fwd_graph(self, args, kwargs, clone_inputs=True):
"""Create a fwd cudagraph for this runner. Should be called inside
'create_cudagraphs()'."""
# save grads and other variables that may be affected by graph warmup
if self.training and torch.is_grad_enabled():
save_main_grads = [
param.main_grad.clone()
for param in self.base_module.parameters()
if hasattr(param, 'main_grad')
]
if self.fp8_enabled:
if is_te_min_version("1.13.0"):
saved_fp8_tensors = save_fp8_tensors([self.base_module], self.fp8_recipe)
else:
saved_fp8_tensors = save_fp8_tensors(
[self.base_module], self.fp8_recipe.amax_history_len
)
if clone_inputs:
args, kwargs = self.zero_out_tensors(args, kwargs)
self.fwd_graph_input_args = args
self.fwd_graph_input_kwargs = kwargs
input_tensors = self.get_tensors(args, kwargs)
self.fwd_graph_input_surface = input_tensors + tuple(self.base_module.parameters())
self.fwd_graph = torch.cuda.CUDAGraph()
# For cases with multiple active RNG states, e.g. TP.
for _, state in get_all_rng_states().items():
self.fwd_graph.register_generator_state(state)
# warmup again as case graph capture mode may execute a different codepath
for _ in range(self.num_warmup_steps):
with self.get_fp8_context():
outputs = self.base_module.forward(
*self.fwd_graph_input_args, **self.fwd_graph_input_kwargs
)
if self.training and torch.is_grad_enabled():
outputs = self.get_tensors(outputs)
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in self.fwd_graph_input_surface if i.requires_grad),
grad_outputs=tuple(
torch.zeros_like(o) if o.requires_grad else None for o in outputs
),
only_inputs=True,
allow_unused=True,
)
with self.get_fp8_context():
torch.cuda.synchronize()
with torch.cuda.graph(self.fwd_graph, pool=self.fwd_mempool):
outputs = self.base_module.forward(
*self.fwd_graph_input_args, **self.fwd_graph_input_kwargs
)
# save cudagraph output buffer
self.fwd_graph_outputs = outputs
self.fwd_graph_output_surface = self.get_tensors(outputs)
if self.training and torch.is_grad_enabled():
assert (
len(self.fwd_graph_output_surface) > 0
), """Tried graphing a module that returned no tensors in training mode,
however the graphed module must output at least one tensor,
so that a corresponding backward node may be registered in the autograd graph."""
# restore cached grads
for param in self.base_module.parameters():
if hasattr(param, 'main_grad'):
saved_grad = save_main_grads.pop(0)
assert (
param.main_grad.shape == saved_grad.shape
), "Error restoring grads while cudagraphing!"
param.main_grad.copy_(saved_grad)
if self.fp8_enabled:
restore_fp8_tensors([self.base_module], saved_fp8_tensors)
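create_fwd_graph builds on PyTorch's stream-capture API: warm up away from the default stream, capture the forward into a torch.cuda.CUDAGraph with static buffers, then copy fresh data into those buffers before each replay. Stripped of the FP8/Transformer-Engine and Megatron-specific handling, the core pattern looks roughly like this standalone sketch (requires a GPU; the Linear module is a placeholder):

import torch

device = "cuda"
module = torch.nn.Linear(16, 16, device=device)
static_input = torch.randn(8, 16, device=device)

# Warm up on a side stream so capture sees no prior work on the default stream.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    for _ in range(2):
        module(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture the forward pass; static_input/static_output become the graph's buffers.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_output = module(static_input)

# Replay: copy new data into the captured input buffer, then replay the graph.
new_batch = torch.randn(8, 16, device=device)
static_input.copy_(new_batch)
graph.replay()
with torch.no_grad():
    print(torch.allclose(static_output, module(new_batch), atol=1e-5))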
def create_bwd_graph(self, static_grad_outputs=None):
"""Create a bwd cudagraph for this runner. Should be called inside
'create_cudagraphs()'."""
self.bwd_graph = torch.cuda.CUDAGraph()
# For cases with multiple active RNG states, e.g. TP.
for _, state in get_all_rng_states().items():
self.bwd_graph.register_generator_state(state)
if static_grad_outputs is None:
static_grad_outputs = tuple(
torch.zeros_like(o) if o.requires_grad else None
for o in self.fwd_graph_output_surface
)
else:
# canonicalize as tuple
if torch.is_tensor(static_grad_outputs):
static_grad_outputs = (static_grad_outputs,)
torch.cuda.synchronize()
with torch.cuda.graph(self.bwd_graph, pool=self.bwd_mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in self.fwd_graph_output_surface if o.requires_grad),
inputs=tuple(i for i in self.fwd_graph_input_surface if i.requires_grad),
grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
retain_graph=self.backward_retain_grad,
only_inputs=True,
allow_unused=True,
)
# Constructs a tuple suitable for returning from Graphed.backward:
# Pads out the actually-needed grads with Nones in gradient slots for inputs
# that don't require grad. I couldn't think of a one-liner for this pattern.
static_grad_inputs = []
grad_idx = 0
for arg in self.fwd_graph_input_surface:
has_wgrad_fusion = self.fuse_wgrad_accumulation and getattr(
arg, "grad_added_to_main_grad", False
)
if arg.requires_grad:
if has_wgrad_fusion:
static_grad_inputs.append(None)
else:
static_grad_inputs.append(grad_inputs[grad_idx])
grad_idx += 1
else:
static_grad_inputs.append(None)
self.groundtruth_grad_added_to_main_grad = {}
if self.fuse_wgrad_accumulation:
for param in self.base_module.parameters():
if hasattr(param, "grad_added_to_main_grad"):
self.groundtruth_grad_added_to_main_grad[param] = param.grad_added_to_main_grad
self.static_grad_outputs = static_grad_outputs
self.static_grad_inputs = static_grad_inputs
def get_input_grads(self):
"""Get the inputs grads that are returned by the bwd cudagraph call. If using grad accum
fusion, wgrads have already been accumulated, so return dummy wgrads."""
if not self.fuse_wgrad_accumulation:
return self.static_grad_inputs
else:
num_dgrads = len(self.static_grad_inputs) - len(list(self.base_module.parameters()))
dgrads = self.static_grad_inputs[:num_dgrads]
wgrads = self.static_grad_inputs[num_dgrads:]
wgrads_with_placeholders = []
for idx, param in enumerate(self.base_module.parameters()):
if getattr(param, "grad_added_to_main_grad", False):
if getattr(param, "zero_out_wgrad", False):
wgrad = torch.zeros(
param.main_grad.shape,
dtype=param.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
param.main_grad.shape,
dtype=param.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = wgrads[idx]
wgrads_with_placeholders.append(wgrad)
return tuple(dgrads + wgrads_with_placeholders)
def record_graph_capture(self, args, kwargs):
"""Records the data needed to create this runner's forward cudagraph.
The first pass records a graph and appends the runner to _CudagraphGlobalRecord.
The actual cudagraph will be created when 'create_cudagraphs()` is called. Subsequent
passes should replay the graph."""
if not self.fwd_graph_recorded:
logger.debug(f"Recording forward graph creation...")
if not self.is_first_layer:
# transformer layers hidden_states are already saved as the output of the previous
# layer's cudagraph so avoid saving again
kwargs_copy = dict(kwargs)
kwargs_copy['hidden_states'] = None
_CudagraphGlobalRecord.record_fwd_graph(self, args, kwargs_copy)
else:
_CudagraphGlobalRecord.record_fwd_graph(self, args, kwargs)
self.fwd_graph_recorded = True
# Run the forward pass as normal in eager mode.
out = super(MegatronModule, self.base_module).__call__(*args, **kwargs)
# Register a noop autograd node that toggles `self.graph_status` in the bwd pass, which
# tracks when the runner completes its bwd pass.
# If it's the first bwd encountered by this runner, record it to _CudagraphGlobalRecord
out = tuple(_CudagraphRecordNode.apply(self, o) if torch.is_tensor(o) else o for o in out)
# autograd nodes return inputs as views, so clone the tensor as returning views may cause
# issues, for instance with pipeline parallelism
return tuple(o.clone() if torch.is_tensor(o) else o for o in out)
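record_graph_capture relies on a no-op autograd Function to learn when the module's backward pass runs. The same trick in isolation looks like this; the BackwardProbe name and the print callback are illustrative only.

import torch

class BackwardProbe(torch.autograd.Function):
    """Identity in forward; runs a callback when backward reaches this node.
    Standalone illustration of the noop-node trick used by _CudagraphRecordNode."""

    @staticmethod
    def forward(ctx, callback, tensor):
        ctx.callback = callback
        return tensor

    @staticmethod
    def backward(ctx, grad):
        ctx.callback()
        return None, grad

# usage
x = torch.randn(4, requires_grad=True)
y = BackwardProbe.apply(lambda: print("backward reached the probe"), x * 2)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])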
def replay_graph_capture(self, is_first_microbatch, args, kwargs):
"""Replay the fwd cuda graph with autograd."""
assert self.matches_graph_inputs(
args, kwargs
), "Tried replaying a cudagraph with different arguments than what it was created with!"
def static_args_match(self, args, kwargs): inp_tensors = self.get_tensors(args, kwargs)
"""Check the the passed args, kwargs match with the arg, kwargs func_args = inp_tensors + tuple(self.parameters())
out = _CudagraphReplayNode.apply(self, is_first_microbatch, *func_args)
out = list(out)
return tuple(out.pop(0) if torch.is_tensor(o) else o for o in self.fwd_graph_outputs)
def matches_graph_inputs(self, args, kwargs):
"""Check that the passed args, kwargs match with the arg, kwargs
the graph was created with.""" the graph was created with."""
def check(val, ref): def check(val, ref):
if isinstance(ref, TensorDescription):
return ref.matches_tensor(val)
return ref == val
if len(args) != len(self.groundtruth_args): _check_supported_type(val)
_check_supported_type(ref)
# check that the args are the same type
if not ((type(val) == type(ref)) or (is_dataclass(val) and is_dataclass(ref))):
return False
# if tensors, check they have the same shape, device and type
# differing memory layout is allowed as 'copy_' is able to interop between different layouts
if isinstance(ref, torch.Tensor):
return (
val.shape == ref.shape and val.dtype == ref.dtype and val.device == ref.device
)
# if dataclass, check args in fields are the same
elif is_dataclass(ref):
for field in fields(ref):
if not check(getattr(val, field.name), getattr(ref, field.name)):
return False
return True
else:
return ref == val
if len(args) != len(self.fwd_graph_input_args):
return False return False
for idx, groundtruth_arg in enumerate(self.groundtruth_args): for arg, graph_arg in zip(args, self.fwd_graph_input_args):
if not check(args[idx], groundtruth_arg): if not check(args, graph_arg):
return False return False
if kwargs.keys() != self.groundtruth_kwargs.keys(): if kwargs.keys() != self.fwd_graph_input_kwargs.keys():
return False return False
for k, v in self.groundtruth_kwargs.items(): for k, v in self.fwd_graph_input_kwargs.items():
if not check(kwargs[k], v): if not check(kwargs[k], v):
return False return False
return True return True
def forward(self, args, kwargs, is_first_microbatch=None): def zero_out_tensors(self, args, kwargs=None):
"""Call the forward pass of the cuda graph.""" """Replace all tensors inside arg, kwargs with zeroed copies."""
if self.training and torch.is_grad_enabled():
args = list(args) def clone_tensor(ten):
for pos in range(len(args)): cloned = torch.zeros_like(ten)
if torch.is_tensor(args[pos]): cloned.requires_grad = ten.requires_grad
args[pos] = GraphStatusFunc.apply(self, args[pos]) return cloned
for k, v in kwargs.items():
if torch.is_tensor(v): def process_arg(arg):
kwargs[k] = GraphStatusFunc.apply(self, v) _check_supported_type(arg)
if torch.is_tensor(arg):
ret_tensors = self.graphed_module(is_first_microbatch=is_first_microbatch, *args, **kwargs) return clone_tensor(arg)
ret_tensors = [ret_tensors] if torch.is_tensor(ret_tensors) else list(ret_tensors) elif is_dataclass(arg):
out = tuple( for field in fields(arg):
ret_tensors.pop(0) if isinstance(o, TensorDescription) else o attr = getattr(arg, field.name)
for o in self.groundtruth_outputs if torch.is_tensor(attr):
) setattr(arg, field.name, clone_tensor(attr))
return arg
# Check that the static graph matches what was recorded during graph capture
assert len(out) == len(self.groundtruth_outputs) args_replaced = []
for idx, o in enumerate(self.groundtruth_outputs): for arg in args:
if isinstance(o, TensorDescription): args_replaced.append(process_arg(arg))
assert o.matches_tensor(out[idx]) if kwargs is None:
return args_replaced
kwargs_replaced = {}
for k, v in kwargs.items():
kwargs_replaced[k] = process_arg(v)
return args_replaced, kwargs_replaced
def get_tensors(self, args, kwargs=None):
"""Filter and flatten all tensors from args and kwargs."""
def extract_tensors(arg):
_check_supported_type(arg)
if torch.is_tensor(arg):
return [arg]
elif is_dataclass(arg):
tens = []
for field in fields(arg):
attr = getattr(arg, field.name)
if torch.is_tensor(attr):
tens.append(attr)
return tens
else: else:
assert o == out[idx] return []
if len(out) == 1: tens = []
return out[0] args, _ = tree_flatten(args)
return out for a in args:
tens.extend(extract_tensors(a))
if kwargs is not None:
kwargs, _ = tree_flatten(kwargs)
for k in kwargs:
tens.extend(extract_tensors(k))
return tuple(tens)
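get_tensors leans on torch.utils._pytree.tree_flatten to pull tensors out of arbitrarily nested args and kwargs. A small demonstration of that helper follows; note it is a private PyTorch API, so the import path may change between versions.

import torch
from torch.utils._pytree import tree_flatten

nested = {
    "hidden_states": torch.randn(2, 3),
    "mask": None,
    "extras": [torch.ones(1), 42, {"bias": torch.zeros(3)}],
}
# tree_flatten returns the leaves plus a spec describing the container structure.
leaves, spec = tree_flatten(nested)
tensors = tuple(leaf for leaf in leaves if torch.is_tensor(leaf))
print(len(tensors))  # 3 tensors found among the flattened leaves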
class CudaGraphManager(torch.nn.Module): class CudaGraphManager(torch.nn.Module):
"""Creates and runs cudagraphs for a megatron module.""" """Creates and runs cudagraphs for a megatron module"""
"""A global mempool for when 'cuda_graph_use_single_mempool' is used."""
global_mempool = None
def __init__(self): """Forward pass mempools, used with cudagraph reuse mode."""
fwd_mempools = None
"""Backward pass mempool, used with cudagraph reuse mode."""
bwd_mempool = None
def __init__(self, config: TransformerConfig):
super().__init__() super().__init__()
rng_tracker = get_cuda_rng_tracker()
assert (HAVE_TE_GRAPHS and isinstance(rng_tracker, TECudaRNGStatesTracker)) or (
isinstance(rng_tracker, CudaRNGStatesTracker) and rng_tracker.use_cudagraphable_rng
), "RNG tracker does not support cudagraphs!"
self.cudagraph_runners = [] self.cudagraph_runners = []
self.is_first_microbatch = True self.is_first_microbatch = False
assert HAVE_TE_GRAPHS, "CudaGraphManager currently requires TransformerEngine"
# Without pipeline parallelism, microbatches execute one at a time.
# Therefore modules will always execute in the same order, so cudagraphs
# can both be reused and share a single mempool.
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
self.reuse_cudagraphs = True
self.use_single_mempool = True
else:
if config.cuda_graph_use_single_mempool:
self.reuse_cudagraphs = False
self.use_single_mempool = True
else:
self.reuse_cudagraphs = True
self.use_single_mempool = False
# Mempools are static so that multiple cudagraph managers may share the same mempool
if self.use_single_mempool:
if CudaGraphManager.global_mempool is None:
CudaGraphManager.global_mempool = torch.cuda.graph_pool_handle()
else:
# All cudagraphs in the same microbatch use the same mempool. For pipeline parallelism,
# additionally all bwd passes share the same mempool
if CudaGraphManager.fwd_mempools is None:
CudaGraphManager.fwd_mempools = defaultdict(
lambda: defaultdict(torch.cuda.graph_pool_handle)
)
CudaGraphManager.bwd_mempool = torch.cuda.graph_pool_handle()
# Cudagraph stream capture requires no operations on the default stream prior to the # Cudagraph stream capture requires no operations on the default stream prior to the
# capture, so change to a side stream. At graph capture change it back. # capture, so change to a side stream.
self.stream = torch.cuda.current_stream() self.stream = torch.cuda.current_stream()
torch.cuda.set_stream(torch.cuda.Stream()) torch.cuda.set_stream(torch.cuda.Stream())
def call_ddp_preforward_hook(self, module):
"""Call any DDP pre-forward hooks which are used to launch async data parallel
param gather. Any other pre-forward hooks are not allowed."""
from megatron.core.distributed import distributed_data_parallel
if module._forward_pre_hooks:
for _, hook in module._forward_pre_hooks.items():
assert (
inspect.getmodule(hook) == distributed_data_parallel
), "Tried to cudagraph a module with user registered pre-forward hooks, \
which is not allowed."
# Only hooks from Mcore DDP, which take no args, should be called at this point.
hook(module)
def get_cudagraph_runner(self, megatron_module):
'''Returns a valid cudagraph runner for the current forward call.
For single mempool mode, we create a cudagraph for each call, if the module is called
multiple times per step, for instance in the case of pipeline parallelism.
The cudagraph corresponding to this call is the first element of 'self.cudagraph_runners'.
We iterate through the list by 1 for each call, and the number of calls is equal to the
length of 'self.cudagraph_runners'.
Otherwise, we assign a mempool per microbatch, which allows cudagraphs to be reused
over different microbatches by tracking their respective fwd and bwd passes.'''
if self.use_single_mempool:
fwd_mempool = CudaGraphManager.global_mempool
bwd_mempool = CudaGraphManager.global_mempool
else:
vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vpp_rank = 0 if vpp_rank is None else vpp_rank
fwd_mempool = CudaGraphManager.fwd_mempools[vpp_rank][len(self.cudagraph_runners)]
bwd_mempool = CudaGraphManager.bwd_mempool
if self.reuse_cudagraphs:
runner = next(
(r for r in self.cudagraph_runners if r.status == _GraphStatus.FWD_READY), None
)
if runner is None:
if _CudagraphGlobalRecord.cudagraph_created:
assert False
else:
runner = _CudaGraphRunner(megatron_module, fwd_mempool, bwd_mempool)
self.cudagraph_runners.append(runner)
else:
# Create cudagraphs for every microbatch
if _CudagraphGlobalRecord.cudagraph_created:
runner = self.cudagraph_runners[0]
assert runner.status == _GraphStatus.FWD_READY
self.cudagraph_runners = self.cudagraph_runners[1:] + self.cudagraph_runners[:1]
else:
runner = _CudaGraphRunner(megatron_module, fwd_mempool, bwd_mempool)
self.cudagraph_runners.append(runner)
return runner
def __call__(self, megatron_module, args, kwargs): def __call__(self, megatron_module, args, kwargs):
"""Calls the forward pass of the cudagraphed module. """Calls the forward pass of the cudagraphed module.
...@@ -219,95 +845,46 @@ class CudaGraphManager(torch.nn.Module): ...@@ -219,95 +845,46 @@ class CudaGraphManager(torch.nn.Module):
""" """
# param.data_ptr() below is used to trigger any hooks that have attached to the parameter. if _CudagraphGlobalRecord.cudagraph_created:
# Specifically, this is trying to trigger the param sync hook for the APEX optimizer, which
# triggers param syncs by hooking into any param references.
# However cudagraphs disables this, so we workaround by manually referencing params here.
# For more information see:
# https://github.com/NVIDIA/apex/blob/7001836/apex/contrib/optimizers/distributed_fused_adam.py#L885C9
for param in megatron_module.parameters():
param.data_ptr()
runner = None
for _runner in self.cudagraph_runners:
if _runner.static_args_match(args, kwargs) and _runner.status == GraphStatus.FWD_READY:
runner = _runner
break
if runner is None:
if self.training and torch.is_grad_enabled(): if self.training and torch.is_grad_enabled():
runner = self.create_cudagraph_module(megatron_module, args, kwargs) # param.data_ptr() below is used to trigger any hooks that have attached to the
self.cudagraph_runners.append(runner) # parameter. Specifically, this is trying to trigger the param sync hook for the
logging.getLogger(__name__).info( # APEX optimizer, which triggers param syncs by hooking into any param references.
f"Creating cudagraph; now have {len(self.cudagraph_runners)}" # However cudagraphs disables this, so we workaround by manually referencing
) # params here. For more information see:
# https://github.com/NVIDIA/apex/blob/7001836/apex/contrib/optimizers/distributed_fused_adam.py#L885C9
for param in megatron_module.parameters():
param.data_ptr()
# Trigger Mcore DDP pre-forward hooks
self.call_ddp_preforward_hook(megatron_module)
for module in megatron_module.modules():
self.call_ddp_preforward_hook(module)
runner = self.get_cudagraph_runner(megatron_module)
out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs)
else:
if 'inference_params' in kwargs.keys() and kwargs['inference_params']:
# Inference generation mode
runner = self.get_cudagraph_runner(megatron_module)
runner.eval()
out = runner.record_graph_capture(args, kwargs)
elif self.training and torch.is_grad_enabled():
# Training mode
runner = self.get_cudagraph_runner(megatron_module)
out = runner.record_graph_capture(args, kwargs)
else: else:
# No cudagraphs were found in inference mode, so fallback to eager since # No cudagraphs were found in training mode with grad disabled, so fallback to
# tensor.requires_grad is needed to correctly trace the backward graph. # eager since autograd is needed to correctly trace the backward graph.
return super(MegatronModule, megatron_module).__call__(*args, **kwargs) return super(MegatronModule, megatron_module).__call__(*args, **kwargs)
tensor_args, tensor_kwargs = self.get_tensor_args(args, kwargs) runner = self.get_cudagraph_runner(megatron_module)
out = runner(tensor_args, tensor_kwargs, is_first_microbatch=self.is_first_microbatch) out = runner.record_graph_capture(args, kwargs)
self.is_first_microbatch = False
return out
def get_tensor_args(self, args, kwargs):
"""Filter out non-tensor arguments from args and kwargs.
Needed since 'make_graphed_callables' expects Torch.tensor arg, kwargs."""
tensor_kwargs = {}
for k, v in kwargs.items():
if torch.is_tensor(v):
tensor_kwargs[k] = v
tensor_args = tuple(arg for arg in args if torch.is_tensor(arg))
return tensor_args, tensor_kwargs
def create_cudagraph_module(self, megatron_module, args, kwargs):
"""Record the graph capture stream. Runs warmup iterations of
megatron_module, and creates a autograd function, where the
forward, backward functions are the cudagraphs of module's forward,
backward passes. Finally wraps this cudagraph function with a CudaGraphRunner.
"""
torch.cuda.synchronize()
torch.cuda.set_stream(self.stream)
start = time.time()
wrapped_module = CudaGraphCallable(megatron_module, args, kwargs)
sample_args, sample_kwargs = self.get_tensor_args(args, kwargs)
# Cudagraphs require no autograd history recorded on sample inputs
sample_args_detached = tuple(n.detach() for n in sample_args)
sample_kwargs_detached = {k: v.detach() for k, v in sample_kwargs.items()}
sample_args_copy = tuple(torch.clone(n) for n in sample_args_detached)
sample_kwargs_copy = {k: torch.clone(v) for k, v in sample_kwargs_detached.items()}
# Zero out input args inplace so cudagraph warmup doesnt affect grads
for orig, detach in zip(sample_args, sample_args_detached):
detach.zero_()
detach.requires_grad = orig.requires_grad
for k, detach in sample_kwargs_detached.items():
detach.zero_()
detach.requires_grad = sample_kwargs[k].requires_grad
fp8_enabled = megatron_module.config.fp8 is not None
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None
graphed_module = make_graphed_callables(
modules=wrapped_module,
sample_args=sample_args_detached,
sample_kwargs=sample_kwargs_detached,
_order=[1, -1],
allow_unused_input=True,
fp8_enabled=fp8_enabled,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
)
# Restore zeroed out sample args # If forward only, next replay should be a forward pass as well
# Detach again since pytorch prohibits inplace ops on leaf nodes if self.training and torch.is_grad_enabled():
for orig, copy in zip(sample_args, sample_args_copy): runner.status = _GraphStatus.BWD_READY
orig.detach().copy_(copy) else:
for k, orig in sample_kwargs.items(): runner.status = _GraphStatus.FWD_READY
orig.detach().copy_(sample_kwargs_copy[k])
logging.getLogger(__name__).info(f'Time spent in cudagraph capture: {time.time() - start}s') return out
return CudaGraphRunner(graphed_module, wrapped_module)