Commit 4b097dee authored by liangjing

update to core_v0.9

parent 3aca1415
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import logging
import os
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import zarr
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import (
LoadShardedStrategy,
SaveShardedStrategy,
StrategyAction,
register_default_strategy,
)
logger = logging.getLogger(__name__)
numpy_to_torch_dtype_dict = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
np.dtype('bool'): torch.bool,
np.dtype('uint8'): torch.uint8,
np.dtype('int8'): torch.int8,
np.dtype('int16'): torch.int16,
np.dtype('int32'): torch.int32,
np.dtype('int64'): torch.int64,
np.dtype('float16'): torch.float16,
np.dtype('float32'): torch.float32,
np.dtype('float64'): torch.float64,
np.dtype('complex64'): torch.complex64,
np.dtype('complex128'): torch.complex128,
}
torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
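# Illustrative note (not part of the original module): the two dicts give a dtype
# round-trip, e.g. numpy_to_torch_dtype_dict[np.dtype('float32')] is torch.float32 and
# torch_to_numpy_dtype_dict[torch.float32] is np.dtype('float32') (the np.dtype entries
# are inserted last, so they win in the inverted mapping).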
try:
# Register a bfloat16 type with this import
import tensorstore # pylint: disable=unused-import
HAS_BFLOAT16 = True
numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
@@ -41,11 +51,28 @@ try:
except ImportError:
HAS_BFLOAT16 = False
_import_trigger = None
def register_default_zarr_strategies():
"""Register default strategies related to Zarr backend."""
register_default_strategy(
StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1)
)
class ZarrSaveShardedStrategy(SaveShardedStrategy):
"""Save strategy for Zarr backend."""
def __init__(self, backend: str, version: int):
super().__init__(backend, version)
logger.warning(
f'`zarr` distributed checkpoint backend is deprecated.'
' Please switch to PyTorch Distributed format (`torch_dist`).'
)
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
sharded_tensors = list(nested_values(sharded_state_dict))
arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
for ten, arr in zip(sharded_tensors, arrays):
_save_to_existing_array(ten, arr)
@@ -54,24 +81,41 @@ class ZarrSaveShardedStrategy(SaveShardedStrategy):
def _create_or_open_zarr_arrays(
sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
"""Returns list of zarr arrays corresponding to given tensors.
For a sharded tensor that:
a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
b) is main replica but not the first chunk,
opens the array created in (a) (possibly by another process)
c) otherwise, sets the corresponding array to None since it won't be used
Args:
sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank
that will be saved to checkpoint
checkpoint_dir (Path): checkpoint in which the arrays will be created
"""
arrays = []
for ten in sharded_tensors:
arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
arrays.append(arr)
torch.distributed.barrier()
# Open arrays created above by other processes
for arr_idx, ten in enumerate(sharded_tensors):
if arrays[arr_idx] is not None:
# array created by this process
assert _should_create_array(ten), ten
continue
if not is_main_replica(ten.replica_id):
# this array won't be needed for saving and can stay None
continue
open_kwargs = {}
if ten.flattened_range is not None:
open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
str(checkpoint_dir / f'{ten.key}.sync')
)
arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
return arrays
@@ -83,9 +127,10 @@ def _should_create_array(ten: ShardedTensor):
)
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
if not is_main_replica(sharded_tensor.replica_id):
return
assert arr is not None
x = sharded_tensor.data
x = x.detach().cpu()
torch.cuda.synchronize()
@@ -114,6 +159,7 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
fill_value=None,
write_empty_chunks=True,
)
logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
except zarr.errors.ContainsArrayError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} already exists'
@@ -127,12 +173,21 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
class ZarrLoadShardedStrategy(LoadShardedStrategy):
"""Load strategy for the Zarr backend."""
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(
partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_zarr_shape_dtype(path):
arr = zarr.open(path, 'r')
return arr.shape, arr.dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
@@ -142,12 +197,7 @@ class ZarrLoadShardedStrategy(LoadShardedStrategy):
def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')
if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
_msg = (
@@ -161,7 +211,22 @@ def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
return postprocess_numpy_array(x, sharded_tensor)
def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
try:
return zarr.open(str(path), mode, **open_kwargs)
except zarr.errors.PathNotFoundError as e:
ckpt_dir = path.parent
err_msg = f'Array {path} not found'
if ckpt_dir.exists():
ckpt_files = [f.name for f in ckpt_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
else:
err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
raise CheckpointingException(err_msg) from e
def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
"""Turn numpy array to torch tensor."""
x = loaded_array
if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
x = x.astype(np.dtype('float32'))
@@ -189,10 +254,12 @@ def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=
def flatten_range(sharded_tensor, x):
"""Apply flattened range to a tensor."""
return x.flatten()[sharded_tensor.flattened_range]
def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
"""Pad tensor to the expected shape."""
pad_args = []
assert len(x.shape) == len(expected_sharded_ten.local_shape)
# Reversed iteration order because F.pad expects so
@@ -204,9 +271,10 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
if x_sh == exp_sh:
pad_args.extend((0, 0))
elif x_sh > exp_sh:
assert False, (
f'Expected shape ({exp_sh}) smaller than actual ({x_sh})'
f' for {repr(expected_sharded_ten)}'
)
else:
pad_args.extend((0, exp_sh - x_sh))
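# Illustrative example (added comment, assuming the usual reversed zip over the two shapes):
# for x of shape (2, 3) and an expected local shape of (2, 5), pad_args ends up as
# [0, 2, 0, 0], i.e. only the last dimension is padded, by 2 elements.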
# TODO: behavior control with envvar is for testing purposes only, remove it
@@ -224,7 +292,30 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
def load_zarr_based_sharded_metadata(
checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
"""Load metadata of Zarr arrays.
Args:
checkpoint_dir (str): checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
an array shape and dtype for a given Zarr array path
"""
sharded_state_dict = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir() or not (subdir / '.zarray').exists():
continue
key = subdir.name
arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
sharded_state_dict[key] = ShardedTensor(
key,
None,
numpy_to_torch_dtype_dict[arr_dtype],
arr_shape,
arr_shape,
tuple(0 for _ in arr_shape),
tuple(1 for _ in arr_shape),
)
return sharded_state_dict
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from typing import Dict, Optional, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
LocalNonpersistentObject,
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
"""Unique id of the sharded tensor data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_tensor (ShardedTensor): sharded tensor representing the data shard
Returns (tuple): unique id of a data shard
"""
f_range = sharded_tensor.flattened_range
return (
sharded_tensor.key,
sharded_tensor.global_offset,
None if f_range is None else (f_range.start, f_range.stop),
)
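# Illustrative example (added comment, hypothetical values): a shard of 'decoder.weight'
# with global_offset (2048, 0) and flattened_range slice(0, 1024) maps to the id
# ('decoder.weight', (2048, 0), (0, 1024)); the same shard replicated on another rank
# produces an identical id.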
def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId:
"""Unique id of the sharded object data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_object (ShardedObject): sharded object representing the data shard
Returns (tuple): unique id of a data shard
"""
return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape)
def extract_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor objects
from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor (keeping the original state dict structure)
- state dict with all objects other than ShardedTensor
(keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
def extract_sharded_tensors_and_factories(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects
from a given state dict with any objects.
Args:
sharded_state_dict:
state dict possibly containing ShardedTensor and ShardedTensorFactory objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor and ShardedTensorFactory
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
)
@@ -29,16 +93,127 @@ def extract_sharded_tensors_and_factories(
def extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor, ShardedTensorFactory
and LocalNonpersistentObject objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory
and LocalNonpersistentObject objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict,
lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)),
)
def extract_sharded_base(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedBase from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedBase objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedBase objects (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase))
def extract_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.
Args:
sharded_state_dict: state dict possibly containing LocalNonpersistentObjects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all LocalNonpersistentObjects
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject)
)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
"""Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict
prefix (str): prefix to be prepended
Returns:
None: state dict is modified in-place
"""
def add_prefix(t):
if isinstance(t, ShardedBase):
t.key = f'{prefix}{t.key}'
return t
dict_list_map_inplace(add_prefix, sharded_state_dict)
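# Illustrative example (added comment): with prefix 'module.' a ShardedTensor keyed
# 'decoder.weight' is rekeyed in place to 'module.decoder.weight'; the prefix is
# prepended verbatim, so any separator must be part of the prefix itself.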
def replace_prefix_for_sharding(
sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
"""Replaces the given prefix in *all* sharded keys in a given state dict.
Errors out if some key does not begin with a given prefix.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
old_prefix (str): prefix to be replaced in each key
new_prefix (str): new prefix
Returns:
None: state dict is modified in place
"""
def _replace_prefix(x):
if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
if not x.key.startswith(old_prefix):
raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
return x
dict_list_map_inplace(_replace_prefix, sharded_state_dict)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
"""Replaces prefixes *only in keys matching* with one of prefixes in the map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
prefix_map (Dict[str, str]):
map of old->new prefixes. The first matching prefix for each key is used
Returns:
None: state dict is modified in place
"""
def _replace_prefixes(x):
if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
return x
for old_prefix, new_prefix in prefix_map.items():
if x.key.startswith(old_prefix):
x.key = (
f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
)
break
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
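# Illustrative example (added comment): with prefix_map {'encoder.': 'enc.', 'decoder.': 'dec.'}
# a key 'decoder.weight' becomes 'dec.weight' (first matching prefix wins), while keys
# matching no prefix are left untouched.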
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import Counter, defaultdict
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import (
extract_matching_values,
map_reduce,
nested_values,
)
from megatron.core.dist_checkpointing.mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.base import (
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
if TYPE_CHECKING:
from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata
logger = logging.getLogger(__name__)
# list of local saved/loaded ShardedBase objects
_LocalMetadata = List[Union[ShardedTensor, ShardedObject]]
# list of lists of global saved/loaded ShardedBase objects (each list element corresponds to global rank)
_GlobalMetadata = List[_LocalMetadata]
class StrictHandling(Enum):
"""Determines handling of load mismatch (non-empty "unexpected" or "missing" keys).
Different flags carry different implications on performance and behaviour and
are divided into two groups:
- *_UNEXPECTED
- *_ALL
The first group ignores missing keys (keys present in the checkpoint but missing
from the sharded state dict); it exists in order to avoid inter-rank
metadata exchange. Note that the metadata exchange will happen anyway
with the `load(..., validate_access_integrity=True)` flag, in which case using an
`*_ALL` option is recommended as it provides a more thorough check with no
performance penalty wrt. the `*_UNEXPECTED` group.
All options except for the first one (`ASSUME_OK_UNEXPECTED`) require
extra disk access before the load in order to remove unexpected keys
from the sharded state dict requested to load.
"""
# Relies on the underlying strategy to raise error on unexpected keys
ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
# Logs (with WARNING level) "unexpected" keys. Missing keys are ignored.
# This is treated as a reasonable default for a "non-strict" load
LOG_UNEXPECTED = 'log_unexpected'
# Logs (with WARNING level) all mismatched keys.
LOG_ALL = 'log_all'
# Raise error on unexpected keys before load attempt.
# Gives cleaner error message than `ASSUME_OK_UNEXPECTED` but requires
# extra disk access.
RAISE_UNEXPECTED = 'raise_unexpected'
# Raise error on any mismatch. Similar to `RAISE_UNEXPECTED` but requires
# metadata exchange.
RAISE_ALL = 'raise_all'
# "Unexpected" mismatches are not reported, but returned by the `load`
# function along with the loaded state dict. Missing keys are ignored.
RETURN_UNEXPECTED = 'return_unexpected'
# All mismatches are returned along with the loaded state dict.
RETURN_ALL = 'return_all'
# Simply ignores mismatches (not recommended)
IGNORE_ALL = 'ignore_all'
@staticmethod
def requires_explicit_ckpt_mismatch_check(val: 'StrictHandling') -> bool:
"""Whether a given strict flag involves mismatch check against the checkpoint."""
return val != StrictHandling.ASSUME_OK_UNEXPECTED
@staticmethod
def requires_global_app_metadata(val: 'StrictHandling') -> bool:
"""Whether a given strict option requires global metadata for validation."""
return val in (
StrictHandling.IGNORE_ALL,
StrictHandling.RAISE_ALL,
StrictHandling.RETURN_ALL,
StrictHandling.LOG_ALL,
)
@staticmethod
def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool:
"""Whether a given strict option results in extra return value from the `load` function."""
return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)
def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
"""Parse user passed strict flag from a string to StrictHandling instance.
Args:
strict (str, StrictHandling): strict flag to parse. If already an instance
of StrictHandling, this function is a noop.
Returns:
StrictHandling: enum instance
"""
if isinstance(strict, StrictHandling):
return strict
try:
return StrictHandling(strict)
except (ValueError, TypeError) as e:
raise ValueError(f'Invalid strict flag: {e}') from e
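# Illustrative example (added comment): parse_strict_flag('log_unexpected') returns
# StrictHandling.LOG_UNEXPECTED, whereas an unknown string such as 'strict' raises
# a ValueError('Invalid strict flag: ...').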
def validate_integrity_and_strict_load(
sharded_state_dict: ShardedStateDict,
strict: StrictHandling,
validate_access_integrity: bool,
local_metadata: Optional[_LocalMetadata] = None,
global_metadata: Optional[_GlobalMetadata] = None,
ckpt_sharded_metadata: Optional['CkptShardedMetadata'] = None,
) -> Tuple[ShardedStateDict, Set[str], Set[str]]:
"""Validates sharding integrity and potential mismatches with the checkpoint.
`validate_access_integrity` controls sharding integrity check (orthogonal
to strictness checking) which verifies `sharded_state_dict` runtime completeness
(in isolation from the actual checkpoint).
`strict` flag controls handling of mismatches between the requested
sharded state dict to load and the actual checkpoint. See `StrictHandling`
docs for details regarding flag behavior and performance implications
(disk interactions or inter-rank communication).
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to verify.
strict (StrictHandling): flag determining how to handle sharded keys mismatch.
validate_access_integrity (bool): whether to perform sharding validation.
local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
Defaults to None, in which case it's determined based on `sharded_state_dict`.
global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
(exchanged between ranks). Defaults to None, in which case "missing"
keys are not determined.
ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
from the checkpoint. Defaults to None, which only makes sense
for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.
Returns:
Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict
without unexpected keys, missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. Additionally,
missing keys might be erroneously empty (depending on `strict` value).
"""
missing_keys, unexpected_keys = [], []
if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
if ckpt_sharded_metadata is None:
raise CheckpointingException(
'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.'
)
if local_metadata is None:
local_metadata = [
sh_base.without_data() for sh_base in nested_values(sharded_state_dict)
]
# We don't want to check for missing keys even if we could
_skip_missing_keys = strict in (
StrictHandling.ASSUME_OK_UNEXPECTED,
StrictHandling.LOG_UNEXPECTED,
StrictHandling.RAISE_UNEXPECTED,
StrictHandling.RETURN_UNEXPECTED,
)
missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys(
ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata
)
sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys)
if strict == StrictHandling.IGNORE_ALL:
missing_keys, unexpected_keys = [], []
elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL):
maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True)
elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL):
maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False)
if validate_access_integrity:
if global_metadata is None:
raise CheckpointingException(
'Cannot check sharding integrity without global_metadata (None).'
)
validate_sharding_integrity(global_metadata)
return sharded_state_dict, missing_keys, unexpected_keys
def verify_checkpoint_and_load_strategy(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
"""Verifies if checkpoint metadata exists and matches given strategies.
If no strategies are passed, they are determined based on the checkpoint metadata.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified
if compatible with the checkpoint content. If None, the default sharded load strategy
for the checkpoint backend will be returned.
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
if compatible with the checkpoint content. If None, the default common load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
elif isinstance(sharded_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
if common_strategy is None:
common_strategy = get_default_strategy(
StrategyAction.LOAD_COMMON,
saved_config.common_backend,
saved_config.common_backend_version,
)
elif isinstance(common_strategy, tuple):
common_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy)
sharded_strategy.check_backend_compatibility(saved_config.sharded_backend)
sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version)
common_strategy.check_backend_compatibility(saved_config.common_backend)
common_strategy.check_version_compatibility(saved_config.common_backend_version)
return sharded_strategy, common_strategy
def adjust_non_strict_load(
sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
"""Adjusts sharded state dict removing keys not existing in the checkpoint.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to modify
sharded_keys_to_remove (Set[str]): keys to remove from the state dict
Returns:
ShardedStateDict: state dict without ShardedBase objects with specified keys
"""
def is_unexpected_key(x: ShardedBase):
assert isinstance(x, ShardedBase), f'Unexpected type {type(x)}'
return x.key in sharded_keys_to_remove
_, sharded_state_dict = extract_matching_values(sharded_state_dict, is_unexpected_key)
return sharded_state_dict
def _determine_missing_and_unexpected_keys(
ckpt_sharded_metadata: 'CkptShardedMetadata',
local_metadata: _LocalMetadata,
global_metadata: Optional[_GlobalMetadata] = None,
) -> Tuple[Set[str], Set[str]]:
"""Determines load mismatches based on metadata.
There is an asymmetry between "unexpected" and "missing" keys.
Unexpected keys can be determined based only on local metadata.
Missing keys must be based on global metadata, since other ranks might access
different keys than the current rank.
In consequence, the return value of this function is different on each rank:
"missing_keys" are equal, but "unexpected_keys" might differ across ranks.
Args:
ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
constructed based on the checkpoint content
local_metadata (_LocalMetadata): list of local ShardedBase objects
requested to be loaded by this rank
global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
requested to be loaded by all ranks. Defaults to None, in which case
returned "missing" keys are empty.
Returns:
Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. If passed
`global_metadata` is empty, returned missing keys are empty as well.
"""
local_accessed_keys = set(sh_base.key for sh_base in local_metadata)
ckpt_keys = set(sh_base.key for sh_base in ckpt_sharded_metadata.values())
unexpected_keys = local_accessed_keys - ckpt_keys
if global_metadata is not None:
global_accessed_keys = set(
sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata
)
missing_keys = ckpt_keys - global_accessed_keys
else:
missing_keys = set()
if missing_keys:
logger.debug(f'Dist ckpt load missing keys: {missing_keys}')
if unexpected_keys:
logger.debug(f'Dist ckpt load unexpected keys: {unexpected_keys}')
return missing_keys, unexpected_keys
def maybe_report_missing_and_unexpected_keys(
missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
"""Raises or logs an error in case missing or unexpected keys are non-empty.
Args:
missing_keys (Set[str]): missing keys in the state dict
unexpected_keys (Set[str]): unexpected keys in the state dict
raise_error: If True, raises error on mismatch. Otherwise, logs mismatch
with WARNING level.
Returns:
None
Raises:
CheckpointingException: if `raise_error` is True and at least one of
`missing_keys` or `unexpected_keys` are non-empty.
"""
if not missing_keys and not unexpected_keys:
return
missing_title_msg = (
f'Some keys found in the checkpoint are missing in the provided sharded state dict. '
)
missing_body_msg = f'Missing keys (for all ranks): {missing_keys}. '
unexpected_title_msg = f'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. '
unexpected_body_msg = f'Unexpected keys (for this rank): {unexpected_keys}. '
error_msg = ''
if missing_keys:
error_msg += missing_title_msg
if unexpected_keys:
error_msg += unexpected_title_msg
error_msg += '\n'
if missing_keys:
error_msg += missing_body_msg
if unexpected_keys:
error_msg += unexpected_body_msg
if raise_error:
raise CheckpointingException(error_msg)
else:
logger.warning(error_msg)
def validate_sharding_integrity(global_metadata: _GlobalMetadata) -> None:
"""Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.
Local ShardedTensor and ShardedObject metadata is exchanged with `torch.distributed.all_gather_object`
and then the process with global rank 0 checks whether the main replicas of the shards:
- cover the whole global tensors
- don't overlap
Args:
global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.
Returns:
None
Raises:
CheckpointingException for invalid access pattern
"""
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(global_metadata):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
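# Illustrative example (added comment): for local_shape (8,), main-replica slices
# (0, 4) and (4, 8) pass the checks below (contiguous cover of 8 elements), whereas
# (0, 4) and (5, 8) would leave a gap and raise a CheckpointingException.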
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.prod(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException(
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
"""Ensure uniqueness of saved objects."""
unique_keys = [
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObjects are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
def determine_global_metadata(
sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
"""Exchanges local metadata with `all_gather_object` to determine global metadata.
Args:
sharded_state_dict (ShardedStateDict): local sharded state dict
Returns:
Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects with stripped data
"""
local_metadata = [ten.without_data() for ten in nested_values(sharded_state_dict)]
global_metadata = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(global_metadata, local_metadata)
return local_metadata, global_metadata
def validate_sharded_objects_handling(
sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
"""Checks if either of the passed strategies can handle sharded objects.
Args:
sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading
common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading
Returns:
None
Raises:
CheckpointingException: if both strategies can't handle ShardedObjects
"""
if (
not sharded_strategy.can_handle_sharded_objects
and not common_strategy.can_handle_sharded_objects
):
raise CheckpointingException(
f'Either sharded strategy or common strategy must implement ShardedObjects handling.'
f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False'
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
# For backwards compatibility. ParamAndGradBuffer will be deprecated in future release.
# ParamAndGradBuffer (which is an alias of _ParamAndGradBuffer) is not intended to be
# consumed directly by external code.
from .param_and_grad_buffer import ParamAndGradBuffer
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
import torch
from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import is_float8tensor, log_single_rank
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
logger = logging.getLogger(__name__)
class DistributedDataParallel(MegatronModule):
"""
DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
communication with backprop computation by breaking up full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
):
super().__init__(config=config)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * parallel_state.get_data_parallel_world_size()
)
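# Illustrative example (added comment): with a data-parallel size of 64 this default
# evaluates to max(40_000_000, 64_000_000) = 64M parameters per bucket.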
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if parallel_state.get_pipeline_model_parallel_rank() > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.param_to_bucket_group = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
self.params_with_grad = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
# Track params with grad to enable direct setting
# of param.grad_added_to_main_grad
self.params_with_grad.append(param)
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def _allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor
):
param_and_grad_dtype_to_params = {}
param_and_grad_dtype_to_offsets = {}
param_and_grad_dtype_to_indices = {}
# Group parameters by their gradient type.
for param in input_params:
assert param.requires_grad
param_dtype = param.dtype
if is_float8tensor(param):
# Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
# dtype (usually a higher precision dtype such as bfloat16), but its actual
# data is stored in the form of a torch uint8 tensor within the Float8Tensor's
# ".data" attribute. Therefore, when creating the param buffer for fp8 params,
# it is necessary to use torch.uint8, not the "fake" dtype got from
# "param.dtype".
param_dtype = torch.uint8
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
# Get the index of each param among the params with same dtype, if a param is fp8,
# use its "fake" high precision dtype to find which params have same dtype with it.
# For example:
# Case 1:
# params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 1, 2, 3],
# }
# Case 2:
# params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2],
# }
# We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
indices.append(offset)
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices
if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
if self.ddp_config.average_in_collective:
# Collective is averaging gradients in collective with data_parallel_group.
assert (
gradient_scaling_factor
/ torch.distributed.get_world_size(group=data_parallel_group)
== target_gradient_scaling_factor
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
_ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
)
)
# In some scenarios, we want to put buckets from different buffers into a group so that
# their communication can be aggregated. For example, when there are both fp8 buffers
# and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
# bucket and a bf16 bucket, which doubles the number of communication kernels, and
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
# communications will prevent the overlap of the communication kernels with computation
# kernels.
# If bucketing is explicitly disabled, then put all buckets in a buffer into a single
# bucket group.
bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
bucket_groups[num_bucket_groups - i - 1]
)
# Create map from param to bucket group, used in pre_hook.
for bucket_group in bucket_groups:
for bucket in bucket_group.buckets:
for param in bucket.params_list:
self.param_to_bucket_group[param] = bucket_group
return buffers, bucket_groups
if config.calculate_per_token_loss:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = (
1.0 / parallel_state.get_expert_model_parallel_world_size()
)
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
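# Illustrative example (added comment): with calculate_per_token_loss=False,
# average_in_collective=False and a data-parallel size of 8, both scaling factors are
# 1/8, i.e. gradients are pre-scaled before the summing collective.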
# Allocate the param+grad buffers for dense params' grads.
self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
dense_params,
parallel_state.get_data_parallel_group(with_context_parallel=True),
gradient_scaling_factor=gradient_scaling_factor,
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
_allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
# Accumulation functions for the gradients need to be stored so they
# don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_backward_post_hook(param))
self.grad_accs.append(grad_acc)
self.use_forward_hook = (
self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
)
self.remove_forward_pre_hook_handles = {}
if self.use_forward_hook:
self.enable_forward_pre_hook()
self.overlap_param_gather_with_optimizer_step = False
def enable_forward_pre_hook(self):
"""
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
assert len(self.remove_forward_pre_hook_handles) == 0
# Register forward pre-hook for all sub-modules.
for module in self.module.modules():
self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
self._make_forward_pre_hook()
)
def disable_forward_pre_hook(self):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
for module in self.module.modules():
assert self.remove_forward_pre_hook_handles[module] is not None
self.remove_forward_pre_hook_handles[module].remove()
del self.remove_forward_pre_hook_handles[module]
assert len(self.remove_forward_pre_hook_handles) == 0
# Force synchronize parameters.
self.start_param_sync(force_sync=True)
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
when a module uses a parameter in a bucket with a still incomplete all-gather).
"""
def hook(module, *unused):
assert (
self.use_forward_hook
), "Should use pre-hook only when overlap_param_gather is True"
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters without an associated buffer (such parameters have a
# .requires_grad field equal to False).
if param not in self.param_to_bucket_group:
continue
assert param.requires_grad
# If aligning param all-gather across pipeline stages, all-gather is dispatched
# by start_param_sync calls in core/pipeline_parallelism/schedules.py.
# If overlapping param all-gather with optimizer step, then all-gather has
# already been dispatched in optimizer step.
skip_next_bucket_dispatch = (
self.ddp_config.align_param_gather
or self.overlap_param_gather_with_optimizer_step
)
self.param_to_bucket_group[param].finish_param_sync(
skip_next_bucket_dispatch=skip_next_bucket_dispatch
)
return hook
def _make_backward_post_hook(self, param: torch.nn.Parameter):
"""
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
ready (i.e., when all grads in a bucket have been computed in all microbatches
in a batch).
"""
def hook(*unused):
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)
return hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = False
try:
yield
finally:
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = True
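# Illustrative usage sketch (added comment, `ddp_model` is a hypothetical instance of this
# wrapper): in a gradient accumulation loop, wrap every microbatch except the last in
# `with ddp_model.no_sync(): loss.backward()` so that grad reduction is deferred until
# the final microbatch's backward pass.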
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync:
# If overlapping param AG with optimizer step, AG should not be dispatched again
# in forward_backward_step.
if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
return
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_param_sync(force_sync=force_sync)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.params_with_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group(
with_context_parallel=True
)
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute."""
align_param_gather: bool = False
"""If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
PP stage will independently launch as needed.
"""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
originally allocated model parameters, otherwise issue all-reduce collectives.
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad = weight.main_grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
embeddings parameters stay in sync.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
assert hasattr(model_module, 'position_embeddings')
grad = model_module.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce layernorm grads (for sequence parallelism).
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
config.sequence_parallel or config.qk_layernorm
):
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if param.requires_grad and (
    getattr(param, 'sequence_parallel', False)
    or 'q_layernorm' in name
    or 'k_layernorm' in name
):
grad = param.main_grad
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_tensor_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce layer-norm grads (for sequence parallelism).
if config.timers is not None:
config.timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_layernorm_grads(model, config)
if config.timers is not None:
config.timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
# Normalize gradients for per-token loss normalization.
# If we are normalizing the loss by the number of tokens, use that count as the divisor;
# it equals the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
pp_group = parallel_state.get_pipeline_model_parallel_group()
if not isinstance(last_rank, list):
    assert not isinstance(pp_group, list)
    last_rank = [last_rank]
    pp_group = [pp_group]
# need to do a broadcast for every pp group, even though num_tokens should be the same.
num_tokens_list = []
for lr, group in zip(last_rank, pp_group):
torch.distributed.broadcast(num_tokens, src=lr, group=group)
num_tokens_list.append(torch.clone(num_tokens))
assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
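# Hedged usage sketch (added for illustration, not part of the original module):
# `_example_finalize_call` is a hypothetical wrapper showing where finalize_model_grads
# fits after the backward pass of the last microbatch; `model_chunks` and
# `total_num_tokens` are assumed to come from the caller's training loop.
def _example_finalize_call(
    model_chunks: List[torch.nn.Module], total_num_tokens: Optional[torch.Tensor] = None
):
    # Reduces grads across DP replicas, all-reduces layernorm grads for sequence
    # parallelism, syncs embedding grads across pipeline stages, and (optionally)
    # scales grads by the global token count.
    finalize_model_grads(model_chunks, num_tokens=total_num_tokens)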
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import math
import os
import warnings
from enum import Enum
from typing import Dict, List, Optional
import torch
from torch.distributed import _coalescing_manager
from ..utils import is_float8tensor, log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = logging.getLogger(__name__)
class BufferType(Enum):
"""
Enumeration for buffer type.
"""
PARAM = 1
GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
]
return sharded_buffer
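# Worked example (added for illustration, not part of the original module): with a
# 1-D buffer of 8 elements and data_parallel_world_size=4, shard_buffer returns four
# contiguous 2-element views; the assert above requires the buffer to divide evenly.
def _example_shard_buffer():
    buffer = torch.arange(8)
    shards = shard_buffer(buffer, data_parallel_world_size=4)
    # shards[0] views elements [0, 1], shards[1] views [2, 3], and so on.
    assert [s.numel() for s in shards] == [2, 2, 2, 2]
    return shards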
class _ParamAndGradBucket:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Args:
params: List of parameters whose gradients are collated in this bucket.
param_data: View in ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
bucket_id: Index of bucket in buffer.
"""
def __init__(
self,
params: List[torch.nn.Parameter],
param_data: Optional[torch.Tensor],
grad_data: torch.Tensor,
offset: int,
numel_unpadded: int,
gradient_scaling_factor: float,
bucket_id: int,
):
self.params_list = params
self.params = set(params)
# Make sure there are no duplicate params.
assert len(self.params_list) == len(self.params)
self.param_data = param_data
self.grad_data = grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self.offset = offset
self.numel_unpadded = numel_unpadded
self.gradient_scaling_factor = gradient_scaling_factor
self.bucket_id = bucket_id
class _ParamAndGradBucketGroup:
"""
Put multiple buckets into a group so that their communications can be aggregated together.
Provides functionality to register when params in the bucket group have grads ready to be
synced; an asynchronous communication call is automatically launched when _all_ params in
the bucket group have grads ready.
Args:
buckets: A list of buckets.
ddp_config: DistributedDataParallel config object.
data_parallel_group: Data-parallel process group.
data_parallel_world_size: World size of the data-parallel group.
"""
def __init__(
self,
buckets: List[_ParamAndGradBucket],
ddp_config: DistributedDataParallelConfig,
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_world_size: int,
):
self.buckets = buckets
self.ddp_config = ddp_config
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
# State for bookkeeping: params is the set of parameters this bucket group is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self.param_to_bucket = {}
self.params = set()
for bucket in self.buckets:
for param in bucket.params_list:
self.param_to_bucket[param] = bucket
self.params.add(param)
self.next_param_gather_bucket_group = None
self.reset()
self.param_gather_handle = None
self.param_gather_dispatched = False
self.grad_reduce_handle = None
def reset(self):
"""
Reset metadata in bucket group in preparation for the next iteration of training.
"""
self.params_with_grad = set()
self.is_last_microbatch = True
def check_for_nan_in_grad(self):
"""
Make sure the norm of grads in each bucket is not NaN prior to the data-parallel
all-reduce / reduce-scatter.
"""
global_rank = torch.distributed.get_rank()
norm_is_nan = self.buckets[0].grad_data.norm(p=2).isnan()
for i in range(1, len(self.buckets)):
norm_is_nan.logical_or_(self.buckets[i].grad_data.norm(p=2).isnan())
assert not norm_is_nan, (
f'Rank {global_rank}: found NaN in local grad norm in '
f'backward pass before data-parallel communication collective. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)
def start_param_sync(self, force_sync: bool = False):
"""
Initiates all necessary param all-gathers for this bucket.
When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous
communication call (unless force_sync is True). When ddp_config.overlap_param_gather
is set to False, makes synchronous call.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings if true.
"""
assert self.ddp_config.use_distributed_optimizer
if force_sync:
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
return
else:
assert self.param_gather_handle is None
async_op = self.ddp_config.overlap_param_gather and not force_sync
# Coalesce communication kernels across buckets in the bucket group.
with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm:
for bucket in self.buckets:
local_data_view = shard_buffer(bucket.param_data, self.data_parallel_world_size)[
self.data_parallel_rank
]
torch.distributed._all_gather_base(
bucket.param_data,
local_data_view,
group=self.data_parallel_group,
async_op=async_op,
)
if async_op:
self.param_gather_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._all_gather_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.param_gather_handle = None
self.param_gather_dispatched = True
def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
"""
Finishes param sync communication operation for this bucket. Dispatches
next bucket's param sync if available, unless skip_next_bucket_dispatch
is True.
When ddp_config.overlap_param_gather is set to True, waits for asynchronous
communication call to complete (and dispatches one if one is not already
outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to
False.
Args:
skip_next_bucket_dispatch (bool, optional): if true, do not dispatch the next
bucket's communication even if available.
"""
assert self.ddp_config.use_distributed_optimizer
assert self.ddp_config.overlap_param_gather
# If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
# AG bucket in first model chunk if ddp_config.align_param_gather is False).
if not self.param_gather_dispatched:
self.start_param_sync()
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
# Dispatch next bucket's asynchronous param AG.
if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
self.next_param_gather_bucket_group.start_param_sync()
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When ddp_config.overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.grad_reduce_handle is None
), 'Should not have multiple communication calls outstanding at once'
if self.ddp_config.check_for_nan_in_grad:
self.check_for_nan_in_grad()
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
for bucket in self.buckets:
if bucket.gradient_scaling_factor != 1.0:
bucket.grad_data *= bucket.gradient_scaling_factor
# Decide reduce_op.
reduce_op = torch.distributed.ReduceOp.SUM
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# Use async communications only when overlap_grad_reduce is True.
async_op = self.ddp_config.overlap_grad_reduce
# Coalesce communication kernels across buckets in the bucket group.
with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm:
for bucket in self.buckets:
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(bucket.grad_data, self.data_parallel_world_size)[
self.data_parallel_rank
]
torch.distributed._reduce_scatter_base(
local_data_view,
bucket.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=async_op,
)
if async_op:
self.grad_reduce_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._reduce_scatter_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.grad_reduce_handle = None
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
self.param_gather_dispatched = False
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
assert self.grad_reduce_handle is not None, (
f'Communication call has not been issued for this bucket '
f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
)
self.grad_reduce_handle.wait()
self.grad_reduce_handle = None
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce
is True.
"""
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
if self.is_last_microbatch:
assert param in self.param_to_bucket, 'Param is not in the bucket group'
assert param not in self.params_with_grad, 'Cannot set grad twice'
self.params_with_grad.add(param)
# If all params in bucket group have grads available, issue communication call.
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()
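# Hedged flow sketch (added for illustration, not part of the original class):
# `_example_grad_sync_flow` is a hypothetical driver showing the intended call order
# when ddp_config.overlap_grad_reduce is True and the last microbatch is being
# processed; in real use, register_grad_ready is invoked from backward hooks.
def _example_grad_sync_flow(bucket_group: _ParamAndGradBucketGroup):
    for param in bucket_group.params:
        # Once every param in the group has registered, the asynchronous
        # reduce-scatter / all-reduce is dispatched automatically.
        bucket_group.register_grad_ready(param)
    # Wait for the outstanding communication before the optimizer step.
    bucket_group.finish_grad_sync()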
class _ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
param_indices: The index of each param among the params with same dtype, if a param is fp8,
use its "fake" high precision dtype to determine which params have same dtype with it.
These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
param_dtype: torch.dtype,
grad_dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
param_indices: List[int],
):
self.ddp_config = ddp_config
self.params = params
self.param_indices = param_indices
# Check that params are unique.
unique_params = set()
for param in params:
assert param not in unique_params
unique_params.add(param)
del unique_params
# Store attributes that will be needed later.
self.param_dtype = param_dtype
self.grad_dtype = grad_dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.gradient_scaling_factor = gradient_scaling_factor
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
"""
Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
"""
if self.ddp_config.use_distributed_optimizer:
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
return bucket_end_index
def _pad_start_of_param_if_needed(param_start_index: int) -> int:
"""
Pads start index of param if using distributed optimizer (to ensure "good" alignment).
"""
if self.ddp_config.use_distributed_optimizer:
# Ensure that params start at 128-byte aligned addresses (64 values
# since params are >= 16-bit precision).
return _pad(param_start_index, 64)
return param_start_index
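# Worked example of the padding arithmetic above (added, illustrative only): with
# data_parallel_world_size=8, a bucket ending at index 1000 is padded to
# _pad(1000, lcm(8, 128)) = _pad(1000, 128) = 1024, and a param starting at index
# 1000 is padded to _pad(1000, 64) = 1024.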
# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
param_start_index = 0
bucket_start_index = param_start_index
bucket_params = set()
self.bucket_indices = []
per_bucket_numel_unpadded = []
bucket_id = 0
def _update_bucket_metadata(param_end_index: int) -> int:
"""
Record metadata for the bucket starting at bucket_start_index and ending with the
passed-in param_end_index. Returns the bucket's end_index.
"""
nonlocal bucket_start_index, bucket_params, bucket_id
per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
# Record metadata of new bucket.
self.bucket_indices.append((bucket_start_index, bucket_end_index))
bucket_start_index = bucket_end_index
# Prepare for next bucket.
bucket_params = set()
bucket_id += 1
# Return the potentially padded bucket_end_index.
return bucket_end_index
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and self.ddp_config.use_distributed_optimizer
)
for param in params[::-1]:
# Iterate through parameters in reverse order to roughly follow backprop order.
this_numel = param.data.nelement()
param_start_index = _pad_start_of_param_if_needed(param_start_index)
# Create bucket with collected parameters if current param needs its own bucket.
if _does_param_require_new_bucket(param):
# We are creating a bucket for the already accumulated parameters, whose params
# end at the current param_start_index.
if self.ddp_config.use_distributed_optimizer:
# Make sure new bucket is appropriately padded.
if param_start_index % self.data_parallel_world_size != 0:
param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_start_index)
param_end_index = param_start_index + this_numel
self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
bucket_params.add(param)
# If we have enough elements already or the current param is part of the shared
# embedding layer and needs a separate bucket, form a new bucket.
if (
bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size
) or _does_param_require_new_bucket(param):
bucket_end_index = _update_bucket_metadata(param_end_index)
param_start_index = bucket_end_index
else:
param_start_index = param_end_index
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_end_index)
# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = bucket_end_index
self.numel_unpadded = sum(per_bucket_numel_unpadded)
assert self.numel_unpadded <= self.numel
if self.ddp_config.use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
else:
assert self.numel == self.numel_unpadded
self.param_data = None
# Only re-map param tensors if using distributed optimizer.
if self.ddp_config.use_distributed_optimizer:
self.param_data = torch.zeros(
self.numel,
dtype=self.param_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
self.grad_data = torch.zeros(
self.numel,
dtype=self.grad_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Finally, map param.data and param.main_grad fields to buffers.
bucket_params = []
bucket_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
param_start_index, param_end_index, bucket_id = self.param_index_map[param]
# Assign param.data to appropriate segment of self.param_data.
if self.param_data is not None:
old_param_data = param.data
new_param_data = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.PARAM
)
if is_float8tensor(param):
param._data = new_param_data
else:
param.data = new_param_data
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
param.main_grad = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.GRAD
)
if bucket_id != cur_bucket_id:
bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
bucket_start_index = bucket_end_index
bucket_params = []
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.append(param)
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
# Log buckets for all PP stages.
log_strs = []
log_strs.append(
f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
)
for index, bucket in enumerate(self.buckets):
numel = 0
for param in bucket.params:
numel += param.data.nelement()
log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
for param in bucket.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
self.grad_data *= scaling_factor
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
"""
Return a tensor with the input `shape` as a view into the 1-D data starting at
`start_index`.
"""
end_index = start_index + shape.numel()
assert end_index <= self.numel, 'Requested tensor is out of buffer range'
if buffer_type == BufferType.PARAM:
assert self.param_data is not None
buffer_tensor = self.param_data[start_index:end_index]
elif buffer_type == BufferType.GRAD:
buffer_tensor = self.grad_data[start_index:end_index]
else:
raise Exception("Illegal buffer type provided to GradBuffer._get() function")
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def _new_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
numel_unpadded: int,
bucket_id: int,
) -> _ParamAndGradBucket:
"""
Helper function that creates a new bucket. Also updates param->bucket mapping.
"""
# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.ddp_config.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]
# Get appropriate view into global ParamAndGradBuffer.
bucketed_param_data = None
if self.param_data is not None:
bucketed_param_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
)
bucketed_grad_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
)
bucket = _ParamAndGradBucket(
params=bucket_params,
param_data=bucketed_param_data,
grad_data=bucketed_grad_data,
offset=start_index,
numel_unpadded=numel_unpadded,
gradient_scaling_factor=self.gradient_scaling_factor,
bucket_id=bucket_id,
)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket
return bucket
def reset(self):
"""
Zero out the underlying grad_buffer.
"""
self.grad_data.zero_()
def partition_buckets(
buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False
) -> List[_ParamAndGradBucketGroup]:
"""
Automatically regroup the buckets of input buffers and return a list of bucket groups.
In some scenarios, we need to put buckets from different buffers into a group so that their
communication can be aggregated.
For example, when there are both fp8 weights and bf16 biases in the model and virtual
pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket,
which doubles the number of communication kernels, and because of the use of
CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the
overlap of communication kernels with computation kernels.
The grouping strategy is:
1. If force_single_bucket_group is True, put all buckets across all buffers into a single
bucket group.
2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers,
let each bucket group have only one bucket.
3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets
into the last fp8 bucket group.
- Since the non-fp8 parameters (typically the biases of various layers) are relatively
small, they are likely to be grouped into a single non-fp8 bucket.
- The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to
the end of the model, while the last bucket corresponds to the beginning.
- If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the
reduce-scatter to synchronize gradients after the backward pass at the end of the model
has completed. This is because we need to wait for the non-fp8 params from the beginning
layers to obtain their gradients.
- Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue.
Args:
buffers (list): list of input buffers.
force_single_bucket_group (bool, optional): if true, force all buckets across all
buffers into a single bucket group.
"""
if len(buffers) == 0:
return []
dtype_to_buffer_map = {}
for buffer in buffers:
dtype = buffer.param_dtype
# Make sure that no two buffers share the same param_dtype.
assert dtype not in dtype_to_buffer_map
dtype_to_buffer_map[dtype] = buffer
# Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True.
if force_single_bucket_group:
buckets = []
ddp_config = buffers[0].ddp_config
data_parallel_group = buffers[0].data_parallel_group
data_parallel_world_size = buffers[0].data_parallel_world_size
for buffer in buffers:
assert ddp_config == buffer.ddp_config
assert data_parallel_group == buffer.data_parallel_group
assert data_parallel_world_size == buffer.data_parallel_world_size
buckets.extend(buffer.buckets)
bucket_group = _ParamAndGradBucketGroup(
buckets, ddp_config, data_parallel_group, data_parallel_world_size
)
return [bucket_group]
if torch.uint8 not in dtype_to_buffer_map:
# Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have
# only one bucket.
bucket_groups = []
for buffer in buffers:
for bucket in buffer.buckets:
bucket_groups.append(
_ParamAndGradBucketGroup(
[bucket],
buffer.ddp_config,
buffer.data_parallel_group,
buffer.data_parallel_world_size,
)
)
return bucket_groups
else:
# Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group.
non_fp8_buckets = []
for buffer in buffers:
if buffer.param_dtype != torch.uint8:
for bucket in buffer.buckets:
non_fp8_buckets.append(bucket)
bucket_groups = []
fp8_buffer = dtype_to_buffer_map[torch.uint8]
for bucket in fp8_buffer.buckets:
if len(bucket_groups) == len(fp8_buffer.buckets) - 1:
# The last bucket group.
group_buckets = [bucket] + non_fp8_buckets
else:
# The first N-1 bucket groups.
group_buckets = [bucket]
bucket_groups.append(
_ParamAndGradBucketGroup(
group_buckets,
buffer.ddp_config,
buffer.data_parallel_group,
buffer.data_parallel_world_size,
)
)
return bucket_groups
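# Worked example of the grouping strategy above (added, illustrative only):
# - Two bf16 buffers with 3 and 2 buckets (no fp8) fall under Case 2 and produce
#   5 single-bucket groups.
# - One fp8 buffer with 3 buckets plus one bf16 buffer with 1 bucket falls under
#   Case 3 and produces 3 groups; the last group holds the final fp8 bucket together
#   with the bf16 bucket.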
# For backwards compatibility. ParamAndGradBuffer will be deprecated in future release.
# _ParamAndGradBuffer is not intended to be consumed directly by external code.
class ParamAndGradBuffer(_ParamAndGradBuffer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"`ParamAndGradBuffer` will be deprecated in a future release, and is not "
"intended to be used by external code."
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import os
import warnings
from typing import Callable
import torch
import transformer_engine as te
from packaging.version import Version as PkgVersion
from torch import Tensor
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_te_version, is_te_min_version
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype}
if is_te_min_version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
def condition_init_method(config, init_method):
"""Condition TE init_method on config.perform_initialization."""
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
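# Hedged usage sketch (added for illustration, not part of the original module):
# `_example_norm` is a hypothetical helper; `config` is assumed to be a fully
# populated TransformerConfig, so the returned instance is either
# te.pytorch.LayerNorm or te.pytorch.RMSNorm depending on config.normalization.
def _example_norm(config: TransformerConfig, hidden_size: int):
    return TENorm(config=config, hidden_size=hidden_size, eps=config.layernorm_epsilon)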
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert and self.expert_parallel:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name
# Disable communications in TE when using SP or EP by making TE agnostic of model parallel.
tp_size = self.config.tensor_model_parallel_size
tp_group = get_tensor_model_parallel_group(check_initialized=False)
if is_expert and (self.config.sequence_parallel or self.expert_parallel):
if self.config.moe_extended_tp:
tp_size = get_tensor_and_expert_parallel_world_size()
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
te_version = get_te_version()
raise ValueError(
f"Transformer Engine v{te_version} does not support {self.config.normalization}."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if is_te_min_version("1.5.0", check_equality=False):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if is_te_min_version("1.6.0.dev0", check_equality=False):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# Older versions don't need attention_type.
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True
# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if is_te_min_version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs['window_size'] = config.window_size
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=(
self.config.attention_dropout if attention_dropout is None else attention_dropout
),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
packed_seq_params: PackedSeqParams = None,
):
"""Forward."""
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
# after init
if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if get_te_version() < PkgVersion("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
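# Illustrative note (added): for S=4, H=2, D=3, a contiguous (1, S, H, D) tensor has
# stride (24, 6, 3, 1), while an equivalent tensor may carry stride (6, 6, 3, 1);
# both describe the same memory because the stride of a size-1 dimension is
# irrelevant. The as_strided call above unifies them so TE's stride check passes.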
if self.te_forward_mask_type:
if qkv_format == 'thd' and is_te_min_version("1.7.0"):
# thd format uses flash attention with cuDNN kernel which requires is_padding=True,
# so the only acceptable mask types are `padding_causal` and `padding`. These do not
# necessarily indicate there are padded tokens in the sequence.
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.padding_causal
elif attn_mask_type == AttnMaskType.no_mask:
attn_mask_type = AttnMaskType.padding
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
if is_te_min_version("1.9.0.dev0"):
class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if self.expert_parallel:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
# For MoE models, the comms between TP and EP group is explicitly handled by
# MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
self.explicit_expert_comm = is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
tp_group = get_tensor_model_parallel_group(check_initialized=False)
if self.explicit_expert_comm and config.moe_extended_tp:
tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
num_gemms=num_gemms,
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def forward(self, x, m_splits):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def _sharded_state_dict_grouped(
self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None
):
"""
prefix should be module_name to make keys identical to sequential ones.
"""
sharded_state_dict = {}
full_state_dict = self.state_dict(prefix='', keep_vars=True)
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_gemms
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_gemms
)
ep_axis = len(sharded_offsets)
for gemm_idx in range(self.num_gemms):
state_dict = {
f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'],
f'{gemm_idx}._extra_state': full_state_dict['_extra_state'],
}
if self.use_bias:
state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}']
sub_sd = make_sharded_tensors_for_checkpoint(
state_dict,
'',
tp_axis_map,
(
*sharded_offsets,
(ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
),
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
sharded_state_dict.update(
{
f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'],
# TODO: TE's GroupedLinear only has one _extra_state for all experts.
# We need sharding or build/merge fn to handle _extra_state correctly.
f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[
f'{gemm_idx}._extra_state'
],
}
)
if self.use_bias:
sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in sharded_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
return sharded_state_dict
class TEColumnParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to column-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {}
for gemm_idx in range(self.num_gemms):
tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0})
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
class TERowParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to row-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 1, bias not sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)}
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
else:
TEGroupedLinear = None
TEColumnParallelGroupedLinear = None
TERowParallelGroupedLinear = None
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` recipe.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
if get_te_version() < PkgVersion("1.8.0"):
extra_kwargs["interval"] = config.fp8_interval
elif config.fp8_interval != 1:
warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
super().__init__(
margin=config.fp8_margin,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""
def is_initialized(self):
"""Checks if the internal RNG state has been set wirth set_states()."""
return self._is_initialized
def reset(self):
"""Reset the internal RNG state."""
super().reset()
self._is_initialized = False
def set_states(self, states):
"""Set the internal RNG state."""
super().set_states(states)
self._is_initialized = True
def add(self, name, seed):
"""Track the rng state."""
super().add(name, seed)
self._is_initialized = True
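# Illustrative usage sketch (not part of this change; assumes a CUDA device is
# available). The tracker mirrors the interface of Megatron's own RNG tracker:
# a named CUDA RNG state is registered with add() and consumed under fork(),
# which is inherited from TE's CudaRNGStatesTracker. The name and seed below
# are placeholders.
def _example_rng_tracker_usage():
    tracker = TECudaRNGStatesTracker()
    tracker.reset()
    assert not tracker.is_initialized()
    tracker.add("example-parallel-rng", 1234)  # registers a state, marks the tracker initialized
    with tracker.fork("example-parallel-rng"):
        # RNG calls here consume the tracked state instead of the default one
        noise = torch.rand(4, device="cuda")
    return noise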
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
):
"""Checkpointing with Transformer-Engine."""
from transformer_engine.pytorch.distributed import checkpoint
if is_te_min_version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
)
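# Hypothetical caller sketch (not part of this commit): te_checkpoint is a thin
# shim over TE's `checkpoint`, which moved the non-tensor arguments to keywords
# in v1.5.0. `layer_forward`, `rng_tracker` and `tp_group` are placeholders for
# whatever the surrounding module provides.
def _example_checkpointed_call(layer_forward, rng_tracker, tp_group, hidden_states, attention_mask):
    return te_checkpoint(
        layer_forward,   # forward_func, re-run during backward
        False,           # distribute_saved_activations
        rng_tracker,     # callable returning the CUDA RNG tracker
        tp_group,        # tensor-parallel process group
        hidden_states,
        attention_mask,
        None,            # context
        None,            # context_mask
        None,            # rotary_pos_emb
    )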
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context as _get_cpu_offload_context,
)
def get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
):
"""Get CPU offload context and sync function."""
if is_te_min_version("1.10.0.dev0"):
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
)
else:
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, activation_offloading, weight_offloading
)
return context, sync_func
except ImportError:
get_cpu_offload_context = None
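# Illustrative sketch (assumes TE's CPU offload support imported successfully):
# the wrapper returns a context manager to run layers under plus a
# synchronization hook to call between layers. The config attribute names below
# are placeholders for the offloading options carried by the transformer config.
def _example_offload_setup(config):
    if get_cpu_offload_context is None:
        return None, None
    return get_cpu_offload_context(
        enabled=config.cpu_offloading,
        num_layers=config.cpu_offloading_num_layers,
        model_layers=config.num_layers,
        activation_offloading=config.cpu_offloading_activations,
        weight_offloading=config.cpu_offloading_weights,
    )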
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple
import torch
from megatron.core.jit import jit_fuser
def _bias_dropout_add_func(x, bias, residual, prob, training):
# type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
def _bias_dropout_add_func(x_with_bias, residual, prob, training):
# type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
# NOTE: Previously, the argument `bias` used to be passed as
# `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
# transformer layer but broadcasting should automatically take care of that.
# Also, looking at broadcasting semantics, `expand_as` and broadcasting
# seem to be identical performance-wise (both just change the view).
x, bias = x_with_bias # unpack
# If we want to train mixed precision, then the output of this function
# should be half precision. However, in AMP O1, the input (residual) is
# in fp32, and it will up-cast the result to fp32, causing pipeline parallel
# GPU communication to hang. Therefore, we need to cast residual to the same
# dtype as x.
residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
# The dropout operation, residual addition and the tensor return could be
# done generically outside the if statement, but that would prevent fusing
# the bias-addition, dropout and residual-addition into one operation. So it
# is done inside the conditional branch to improve performance.
if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
else:
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
def bias_dropout_add_unfused(training):
def _bias_dropout_add(x_with_bias, residual, prob):
return _bias_dropout_add_func(x_with_bias, residual, prob, training)
return _bias_dropout_add
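# Minimal usage sketch (illustrative only): every variant consumes the packed
# (output, bias) tuple plus the residual stream. The shapes below are
# placeholder [seq, batch, hidden] values.
def _example_bias_dropout_add():
    x = torch.randn(8, 2, 16)
    bias = torch.randn(16)
    residual = torch.randn(8, 2, 16)
    fn = bias_dropout_add_unfused(training=True)
    return fn((x, bias), residual, 0.1)  # same shape as residual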
@torch.jit.script
@jit_fuser
def bias_dropout_add_fused_train(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, True)
return _bias_dropout_add_func(x_with_bias, residual, prob, True)
@torch.jit.script
@jit_fuser
def bias_dropout_add_fused_inference(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, False)
return _bias_dropout_add_func(x_with_bias, residual, prob, False)
def get_bias_dropout_add(training, fused):
def unfused_bias_dropout_add(x_with_bias, residual, prob):
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, training)
if fused:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
......@@ -57,4 +70,4 @@ def get_bias_dropout_add(training, fused):
else:
return bias_dropout_add_fused_inference
else:
return unfused_bias_dropout_add
return bias_dropout_add_unfused(training)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.jit import jit_fuser
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def geglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return (y_1 * 0.5 * (1.0 + torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)))) * y_2
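# Quick numerical check (illustrative, not part of the fused kernel): the tanh
# expression above approximates exact GELU, x * 0.5 * (1.0 + torch.erf(x * 0.70710678));
# on moderate inputs the two agree to well within 1e-2.
def _check_gelu_tanh_approximation():
    x = torch.linspace(-4.0, 4.0, 101, dtype=torch.float64)
    tanh_gelu = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    exact_gelu = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
    return (tanh_gelu - exact_gelu).abs().max()  # small maximum deviation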
@jit_fuser
def bias_geglu(bias, y):
y = y + bias
return geglu(y)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
tanh_out = torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * y_1 * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * y_1 * y_1)) + 0.5 * (
1 + tanh_out
)
return torch.cat(((g * y_2) * ff, g * (y_1 * 0.5 * (1.0 + tanh_out))), -1)
@jit_fuser
def bias_geglu_back(g, y, bias):
y = y + bias
return geglu_back(g, y)
class BiasGeGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_geglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_geglu_back(grad_output, input, bias)
return tmp, tmp
class GeGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return geglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
tmp = geglu_back(grad_output, input[0])
return tmp
def bias_geglu_impl(input, bias):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasGeGLUFunction.apply(input, bias)
else:
output = GeGLUFunction.apply(input)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
......@@ -2,7 +2,9 @@
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
from megatron.core.jit import jit_fuser
# BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
......@@ -11,7 +13,7 @@ import torch
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
@jit_fuser
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
......@@ -20,7 +22,7 @@ def bias_gelu(bias, y):
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
@jit_fuser
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
......@@ -44,5 +46,10 @@ class GeLUFunction(torch.autograd.Function):
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
# This is required to make Sphinx happy :-(
@classmethod
def apply(cls, *args, **kwargs):
return super().apply(*args, **kwargs)
bias_gelu_impl = GeLUFunction.apply
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
from megatron.core.jit import jit_fuser
###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
@jit_fuser
def swiglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return F.silu(y_1) * y_2
@jit_fuser
def bias_swiglu(y, bias):
y = y + bias
return swiglu(y)
# gradient of swiglu
# uses d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
@jit_fuser
def swiglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
return torch.cat(
(g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1
)
@jit_fuser
def bias_swiglu_back(g, y, bias):
y = y + bias
return swiglu_back(g, y)
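# Illustrative gradient check (not part of the fused path): swiglu_back is the
# analytic gradient of swiglu with respect to its packed input and can be
# compared against autograd on a small double-precision tensor.
def _check_swiglu_backward():
    y = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    g = torch.randn(4, 4, dtype=torch.float64)
    (auto_grad,) = torch.autograd.grad(swiglu(y), y, grad_outputs=g)
    manual_grad = swiglu_back(g, y.detach())
    return torch.allclose(auto_grad, manual_grad)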
class BiasSwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward, bias)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return bias_swiglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = bias_swiglu_back(grad_output, input, bias)
return tmp, tmp, None
class SwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return swiglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors[0]
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = swiglu_back(grad_output, input)
return tmp, None
def bias_swiglu_impl(input, bias, fp8_input_store=False):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
else:
output = SwiGLUFunction.apply(input, fp8_input_store)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
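# Usage sketch (illustrative): the caller passes the raw GEMM output of shape
# [seq, batch, 2 * ffn_hidden] plus an optional bias; setting fp8_input_store=True
# keeps the activation saved for backward in float8_e4m3fn to reduce memory.
def _example_bias_swiglu_impl():
    y = torch.randn(8, 2, 32, requires_grad=True)  # [seq, batch, 2 * ffn_hidden]
    bias = torch.randn(32)
    out = bias_swiglu_impl(y, bias, fp8_input_store=False)
    return out.shape  # torch.Size([8, 2, 16])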
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.jit import jit_fuser
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility
@jit_fuser
def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
return vocab_parallel_logits, logits_max
@jit_fuser
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
(target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))
return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits
@jit_fuser
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
split_val = predicted_logits_sum_exp_logits.size()[0] // 2
predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val)
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
return exp_logits, loss
@jit_fuser
def calculate_gradients(
softmax: torch.Tensor,
grad_output: torch.Tensor,
target_mask: torch.Tensor,
masked_target_1d: torch.Tensor,
) -> torch.Tensor:
(grad_2d, arange_1d, softmax_update, grad_input) = (
VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
)
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
grad_input = grad_input.to(torch.bfloat16)
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
(target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
# All reduce is needed to get the chunks from other GPUs.
# In the fused case, the tensors are batched to invoke a single
# AllReduce call
torch.distributed.all_reduce(
predicted_logits_sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)
return grad_input, None
def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, vocab_size/num_tensor_parallel_ranks]
target: correct vocab ids of dimension [sequence_length, micro_batch_size]
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
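# Illustrative call pattern (assumes torch.distributed and Megatron's
# tensor-model-parallel group are already initialized; sizes and dtype are
# placeholders). Each rank holds only its own vocab shard of the logits, while
# `target` carries vocab ids from the full (unsplit) vocabulary.
def _example_fused_cross_entropy(seq_len=4, micro_batch=2, vocab_per_rank=128):
    logits_shard = torch.randn(
        seq_len, micro_batch, vocab_per_rank, device='cuda', dtype=torch.bfloat16
    )
    target = torch.randint(0, vocab_per_rank, (seq_len, micro_batch), device='cuda')
    return fused_vocab_parallel_cross_entropy(logits_shard, target)  # [seq_len, micro_batch]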
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
import inspect
import numbers
import torch
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
except:
except ImportError:
HAVE_PERSIST_LAYER_NORM = False
try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
HAVE_FUSED_LAYER_NORM = True
except:
except ImportError:
HAVE_FUSED_LAYER_NORM = False
class FusedLayerNorm(torch.nn.Module):
"""Layer Norm, fused into a single CUDA kernel.
Args:
hidden_size (int): Transformer hidden dimension.
eps (float): Epsilon added to denominator, for numerical stability.
persist_layer_norm (bool): Use persistent fused layer norm kernel.
This kernel supports only a set of hidden sizes. Please
check persist_ln_hidden_sizes to see whether your hidden size is supported.
zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.
config (TransformerConfig): Transformer config. Include to match custom
layer norm interfaces.
normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
"""
def __init__(
self,
hidden_size,
eps=1e-5,
persist_layer_norm=True,
sequence_parallel=False,
zero_centered_gamma=False,
config: TransformerConfig,
hidden_size: int,
eps: float = 1e-5,
persist_layer_norm: bool = True,
zero_centered_gamma: bool = False,
normalization: str = "LayerNorm", # included to match TE interface
):
super().__init__()
self.zero_centered_gamma = zero_centered_gamma
self.config = config
self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
assert (
self.config.normalization == "LayerNorm"
), f'({self.config.normalization}) is not supported in FusedLayerNorm'
# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
......@@ -66,22 +96,24 @@ class FusedLayerNorm(torch.nn.Module):
49152,
65536,
]
persist_layer_norm = self.config.persist_layer_norm
if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
persist_layer_norm = False
if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
# TODO: Add pytorch only layer norm
raise ValueError(f'Apex must currently be installed to use megatron core.')
raise ValueError('Apex must be installed to use FusedLayerNorm.')
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
self.weight = Parameter(torch.Tensor(*hidden_size))
self.bias = Parameter(torch.Tensor(*hidden_size))
# Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
self.weight = Parameter(torch.empty(*hidden_size))
self.bias = Parameter(torch.empty(*hidden_size))
self.reset_parameters()
self.persist_layer_norm = persist_layer_norm
self.sequence_parallel = sequence_parallel
self.sequence_parallel = self.config.sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
......@@ -96,11 +128,16 @@ class FusedLayerNorm(torch.nn.Module):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
def forward(self, input: Tensor) -> Tensor:
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
if self.persist_layer_norm:
if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
output = FastLayerNormFN.apply(
input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm
)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
......@@ -112,7 +149,20 @@ class FusedLayerNorm(torch.nn.Module):
)
else:
output = FusedLayerNormAffineFunction.apply(
if (
'memory_efficient'
in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args
):
return FusedLayerNormAffineFunction.apply(
input,
weight,
self.bias,
self.hidden_size,
self.eps,
self.config.memory_efficient_layer_norm,
)
else:
return FusedLayerNormAffineFunction.apply(
input, weight, self.bias, self.hidden_size, self.eps
)
......
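# Usage sketch (illustrative; assumes Apex is installed, a CUDA device is
# available and `config` is a populated TransformerConfig): persist_layer_norm,
# sequence_parallel and layernorm_zero_centered_gamma are read from the config
# rather than from per-instance keyword arguments.
def _example_fused_layer_norm(config: TransformerConfig, hidden_size: int = 1024):
    norm = FusedLayerNorm(config=config, hidden_size=hidden_size, eps=1e-5).cuda()
    x = torch.randn(8, 2, hidden_size, device='cuda')
    return norm(x)  # same shape as x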
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
import torch.nn as nn
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import get_default_causal_mask
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
......@@ -96,7 +98,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
Args:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
......@@ -131,7 +133,12 @@ class FusedScaleMaskSoftmax(nn.Module):
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]):
"""Forward pass of softmax with masked input.
When attn_mask_type is causal, the mask is generated internally and None can be passed.
A user-defined mask is only needed when attn_mask_type is not causal.
"""
# [b, np, sq, sk]
assert input.dim() == 4
......@@ -186,6 +193,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.scale is not None:
input = input * self.scale
# Generate causal mask if not given
sq, sk = input.size(2), input.size(3)
if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1:
# If sq == 1 then either KV cache is used or one-element context is passed
# so keeping mask=None in this case; subsequent code should handle it
assert sq == sk, "causal mask is only for self attention"
mask = get_default_causal_mask(sq)
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
......
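# Hypothetical usage (illustrative only): the first constructor arguments follow
# the Args list above; the remaining names and the mask_func lambda are
# assumptions standing in for Megatron's usual attention mask helper.
def _example_fused_softmax(attention_scores: torch.Tensor):
    # attention_scores is expected to be [batch, num_heads, sq, sk]
    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        input_in_bf16=False,
        attn_mask_type=AttnMaskType.causal,
        scaled_masked_softmax_fusion=True,
        mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),
        softmax_in_fp32=True,
        scale=None,
    )
    # With a causal attn_mask_type the mask argument may be None; it is built internally.
    return softmax(attention_scores, None)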
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
warnings.warn(
"The 'megatron.core.inference.ammo_support' module is deprecated and will be removed in a future release. "
"Please use megatron.core.inference.modelopt_support instead",
DeprecationWarning,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec