Commit d520d24f authored by silencealiang's avatar silencealiang
Browse files

Merge branch 'main' into 'main'

megatron升级v0.10

See merge request OpenDAS/megatron-lm!3
parents 3aca1415 481609bb
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from time import time
from typing import Any, Callable, Optional
import torch
from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values
from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution
from .mapping import (
CommonStateDict,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
extract_nonpersistent,
extract_sharded_base,
)
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
def save_preprocess(
sharded_state_dict: ShardedStateDict,
validate_access_integrity: bool = True,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
):
"""Preprocesses the given state dictionary by applying factories,
discarding non-persistent data and extracting the common state dictionary.
Optionally, it can validate sharding integrity.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed.
validate_access_integrity (bool): If True, triggers validation of sharding integrity.
preprocess_common_before_consistancy_check (callable, None): A callable function
that will preprocess the common state dict (i.e can be used to remove keys
that we expect to be different in the state dict)
Returns:
Tuple[ShardedStateDict, dict]:
The preprocessed sharded state dictionary and the common state dictionary.
"""
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
preprocessed_common_state_dict = common_state_dict
if preprocess_common_before_consistancy_check:
preprocessed_common_state_dict = preprocess_common_before_consistancy_check(
common_state_dict
)
validate_sharding_integrity(
determine_global_metadata(sharded_part)[1],
common_state_dict=preprocessed_common_state_dict,
)
return sharded_part, common_state_dict
def load_preprocess(sharded_state_dict: ShardedStateDict):
"""Preprocesses the given state dictionary by applying factories
and extracting non-persistent data, without modifying the original dictionary.
Args:
sharded_state_dict (ShardedStateDict):
The initial state dictionary to be processed (remains unchanged).
Returns:
Tuple[ShardedStateDict, dict, dict]:
- A preprocessed copy of the sharded state dictionary.
- A dictionary containing non-persistent state data.
- A dictionary of `ShardedTensorFactory` instances.
"""
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories)
# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories
def prepare_state_dict_for_save(
sharded_state_dict: ShardedStateDict,
async_prepare: bool = False,
algo: str = 'atomic',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
to_cpu: bool = True,
):
"""Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary.
async_prepare (bool): If True, enables asynchronous preparation.
algo (str): The algorithm used to create the tensor-aware state dictionary.
validate_access_integrity (bool): If True, validates sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for exchanges to avoid duplications.
to_cpu (bool): If True, moves all tensors from device to CPU.
Returns:
ShardedStateDict: The tensor-aware state dictionary.
"""
_start = time()
if async_prepare:
raise NotImplementedError('Async state_dict preparation is not yet implemented')
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity)
sharded_tensors = []
sharded_objects = []
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
sharded_tensors.append(sh_base)
else:
assert isinstance(sh_base, ShardedObject)
sharded_objects.append(sh_base)
if fully_parallel:
shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
raw_tensors, raw_objects = {}, {}
for ten in sharded_tensors:
shard_id = _sharded_tensor_shard_id(ten)
if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank():
# TODO cover creating copies on host in CheckpointManager.save()
if to_cpu:
raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True)
else:
raw_tensors[shard_id] = ten.data
ten.data = None
for obj in sharded_objects:
raw_objects[_sharded_object_id(obj)] = obj.data
obj.data = None
logger.debug(f'prepare_state_dict_for_save took {time() - _start}')
state_dict_for_save = {
'raw_tensors': raw_tensors,
'raw_objects': raw_objects,
'common': common_state_dict,
'sharded_state_dict': sharded_part,
}
if fully_parallel:
state_dict_for_save['shard_to_rank'] = shard_to_saving_rank
state_dict_for_save['shard_to_metadata'] = shard_to_metadata
return state_dict_for_save
def recreate_state_dict_after_load(
sharded_state_dict: ShardedStateDict,
loaded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""Creates a final sharded state dictionary from a tensor-aware state dictionary.
Args:
sharded_state_dict (ShardedStateDict):
The initial sharded state dictionary generated from the model.
loaded_state_dict (ShardedStateDict):
Tensor-aware state dictionary used to fill in missing data in the sharded state.
algo (str): The algorithm used to reconstruct the state dictionary
from the tensor-aware state dictionary.
exchange_algo (str): The algorithm used for tensor exchanges during retrieval.
validate_access_integrity (bool): If True, performs validation of sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for efficient exchanges during retrieval.
Returns:
ShardedStateDict: The finalized sharded state dictionary.
"""
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate laoded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
loaded_tensors = loaded_state_dict['raw_tensors']
# TODO cover restoring the original device (H2D) in CheckpointManager.load()
for k, v in loaded_tensors.items():
loaded_tensors[k] = v.cuda() # H2D
if fully_parallel:
distribution = (
loaded_state_dict['shard_to_rank'],
None,
loaded_state_dict['shard_to_metadata'],
)
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
loaded_tensors = exchange_by_distribution(
loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo
)
loaded_objects = loaded_state_dict['raw_objects']
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
if shard_id not in loaded_tensors:
raise Exception(
'The current local checkpoint implementation assumes'
'consistent tensor sharding during load and save operations.'
f'However, the expected shard {x} (ID: {shard_id})'
f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})'
)
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """
from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies
import logging
logger = logging.getLogger(__name__)
try:
import tensorstore
import zarr
from .tensorstore import _import_trigger
from .zarr import _import_trigger
except ImportError:
logger.warning('Zarr-based strategies will not be registered because of missing packages')
# We load "common" strategies by default to be always available
register_default_common_strategies()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This module provides an async utilities which allow to start
a checkpoint save process in the background.
"""
import logging
from collections import deque
from time import time
from typing import Callable, List, NamedTuple, Optional, Tuple
import torch
from torch import multiprocessing as mp
logger = logging.getLogger(__name__)
class AsyncRequest(NamedTuple):
"""Represents an async request that needs to be scheduled for execution.
Args:
async_fn (Callable, optional): async function to call. None represents noop.
async_fn_args (Tuple): args to pass to `async_fn`.
finalize_fns (List[Callable]): list of functions to call to finalize the request.
These functions will be called synchronously after `async_fn` is done
*on all ranks*.
"""
async_fn: Optional[Callable]
async_fn_args: Tuple
finalize_fns: List[Callable]
is_frozen: bool = False
def add_finalize_fn(self, fn: Callable) -> None:
"""Adds a new finalize function to the request.
Args:
fn (Callable): function to add to the async request. This function
will be called *after* existing finalization functions.
Returns:
None
"""
if self.is_frozen:
raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
self.finalize_fns.append(fn)
def execute_sync(self) -> None:
"""Helper to synchronously execute the request.
This logic is equivalent to what should happen in case of the async call.
"""
if self.async_fn is not None:
self.async_fn(*self.async_fn_args)
torch.distributed.barrier()
for finalize_fn in self.finalize_fns:
finalize_fn()
def freeze(self) -> 'AsyncRequest':
"""Freezes the async request, disallowing adding new finalization functions.
Returns:
AsyncRequest: new async request with all same fields except for the
`is_frozen` flag.
"""
return self._replace(is_frozen=True)
class DistributedAsyncCaller:
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None
def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None:
"""Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
save_args (Tuple): async function args.
"""
if async_fn is None:
return # nothing to do
start_sync = time()
torch.cuda.synchronize()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {end_sync - start_sync} to finish D2H "
)
ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(target=async_fn, args=save_args)
self.process.start()
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} to schedule async ckpt "
)
def is_current_async_call_done(self, blocking=False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
Returns:
bool: True if all ranks are done (immediately of after active wait
if `blocking` is True), False if at least one rank is still active.
"""
# The following takes the same overhead as torch.distributed.barrier (single integer all-reduce)
is_alive = int(self.process.is_alive()) if self.process is not None else 0
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
logger.debug(
f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}"
)
torch.distributed.all_reduce(ten)
if ten[0] > 0 and not blocking:
return False
else:
if self.process is not None:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
return True
class _ActiveAsyncRequest(NamedTuple):
"""Helper to represent an active async call.
Args:
idx (int): index of the call (starting from 0)
async_caller (DistributedAsyncCaller): async caller instance that represents
the async process handling the async request
async_request (AsyncRequest): async request that is being called
"""
idx: int
async_caller: DistributedAsyncCaller
async_request: AsyncRequest
class AsyncCallsQueue:
"""Manages a queue of async calls.
Allows adding a new async call with `schedule_async_request` and finalizing
active calls with `maybe_finalize_async_calls`.
"""
def __init__(self):
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
self.call_idx: int = -1
def schedule_async_request(self, async_request: AsyncRequest) -> int:
"""Start a new async call and add it to a queue of active async calls.
This method must be called on all ranks.
Args:
async_request (AsyncRequest): async request to start.
Returns:
int: index of the async call that was started.
This can help the user keep track of the async calls.
"""
self.call_idx += 1
async_caller = DistributedAsyncCaller()
async_request = async_request.freeze()
async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
return self.call_idx
def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
"""Finalizes all available calls.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until all active requests
are done. Otherwise, finalizes only the async request that already
finished. Defaults to False.
Returns:
List[int]: list of indices (as returned by `schedule_async_request`)
of async calls that have been successfully finalized.
"""
call_idx_finalized = []
while self.async_calls:
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
if not next_async_done:
break
call_idx, _, async_request = self.async_calls.popleft()
for finalize_fn in async_request.finalize_fns:
finalize_fn()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert (
ten.item() == call_idx
), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
call_idx_finalized.append(call_idx)
return call_idx_finalized
def get_num_unfinalized_calls(self):
"""Get the number of active async calls."""
return len(self.async_calls)
def close(self):
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies base interfaces. """
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, DefaultDict, Union
from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict
from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest
class StrategyAction(Enum):
"""Specifies save vs load and sharded vs common action."""
LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
default_strategies = defaultdict(dict)
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
async_calls = AsyncCallsQueue()
def get_default_strategy(action: StrategyAction, backend: str, version: int):
"""Retrieves a default strategy for a given action, backend and version."""
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
from .tensorstore import register_default_tensorstore_strategies
register_default_tensorstore_strategies()
from .zarr import register_default_zarr_strategies
register_default_zarr_strategies()
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import register_default_torch_strategies
register_default_torch_strategies()
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. '
f'Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
raise CheckpointingException(
f'Cannot find default strategy for: {(action, backend, version)}'
f'Cannot find a default strategy for: {(action.value, backend, version)}'
) from e
def register_default_strategy(
action: StrategyAction,
backend: str,
version: int,
strategy: Union['SaveStrategyBase', 'LoadStrategyBase'],
):
"""Adds a given strategy to the registry of default strategies.
Args:
action (StrategyAction): specifies save/load and sharded/common
backend (str): backend that the strategy becomes a default for
version (int): version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
"""
default_strategies[action.value][(backend, version)] = strategy
class LoadStrategyBase(ABC):
"""Base class for a load strategy. Requires implementing checks for compatibility with a
given checkpoint version."""
@abstractmethod
def check_backend_compatibility(self, loaded_version):
def check_backend_compatibility(self, loaded_backend):
"""Verifies if this strategy is compatible with `loaded_backend`."""
raise NotImplementedError
@abstractmethod
def check_version_compatibility(self, loaded_version):
"""Verifies if this strategy is compatible with `loaded_version`."""
raise NotImplementedError
@property
def can_handle_sharded_objects(self):
"""Returns whether or not this strategy can handle loading ShardedObjects."""
return False
class SaveStrategyBase(ABC):
"""Base class for a save strategy. Requires defining a backend type and
version of the saved format."""
def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version
@property
def can_handle_sharded_objects(self):
"""Returns whether or not this strategy can handle saving ShardedObjects."""
return False
def __str__(self):
return f'{self.__class__.__name__}({self.backend}, {self.version})'
class LoadCommonStrategy(LoadStrategyBase):
"""Load strategy for common (non-sharded) objects"""
@abstractmethod
def load(self, checkpoint_dir: Path):
def load_common(self, checkpoint_dir: Path):
"""Load common part of the checkpoint."""
raise NotImplementedError
@abstractmethod
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Load sharded objects from the checkpoint."""
raise NotImplementedError
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Load just the metadata from the checkpoint."""
if not self.can_handle_sharded_objects:
return {}
raise NotImplementedError
class LoadShardedStrategy(LoadStrategyBase):
"""Load strategy for sharded tensors"""
@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Load the sharded part of the checkpoint."""
raise NotImplementedError
@abstractmethod
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so, the
only useful information is tensors global shape and dtype).
"""
raise NotImplementedError(
f'Loading only tensors metadata not implemented for {self.__class__.__name__}'
)
def load_sharded_metadata(self, checkpoint_dir: Path):
"""Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply sharded keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors or ShardedObjects without any data and sharding.
"""
if not self.can_handle_sharded_objects:
return self.load_tensors_metadata(checkpoint_dir)
raise NotImplementedError(
f'Loading only sharded metadata not implemented for {self.__class__.__name__}'
)
def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
"""Remove all tensors whose key starts with key_prefix"""
raise NotImplementedError
class SaveCommonStrategy(SaveStrategyBase):
"""Save strategy for common (non-sharded) objects"""
@abstractmethod
def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
raise NotImplementedError
def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Save sharded objects from the state dict."""
raise NotImplementedError
class SaveShardedStrategy(SaveStrategyBase):
"""Save strategy for sharded tensors"""
@abstractmethod
def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Save the sharded part of the state dict."""
raise NotImplementedError
class AsyncSaveShardedStrategy(SaveShardedStrategy):
"""Save strategy suitable for async save."""
@abstractmethod
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
"""Perform preparation and return an AsyncRequest to the external caller.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint target directory
Returns:
AsyncRequest: represents the async save function and finalization function.
It is the caller responsibility to actually schedule the async save.
"""
raise NotImplementedError
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Each async strategy can be trivially used as a sync strategy."""
async_request = self.async_save(sharded_state_dict, checkpoint_dir)
# multiprocessing routines may cause issue when called on parent process
# We keep this verbose call for now
global async_calls
async_calls.schedule_async_request(async_request)
async_calls.maybe_finalize_async_calls(blocking=True)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
""" Common strategies. """
import logging
import os
from pathlib import Path
import torch
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict
from megatron.core.dist_checkpointing.strategies.base import (
SaveCommonStrategy,
StrategyAction,
register_default_strategy,
)
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import CheckpointingException, ShardedObject, is_main_replica
from ..strategies.base import LoadCommonStrategy
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def register_default_common_strategies():
"""Register default common strategies."""
register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy())
register_default_strategy(
StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1)
)
class TorchCommonSaveStrategy(SaveCommonStrategy):
"""Common save strategy leveraging native torch save/load."""
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Save sharded objects from the state dict."""
for sh_obj in nested_values(sharded_objects_state_dict):
if is_main_replica(sh_obj.replica_id):
save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True
class TorchCommonLoadStrategy(LoadCommonStrategy):
"""Common load strategy leveraging native torch save/load."""
def load_common(self, checkpoint_dir: Path):
"""Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
raise CheckpointingException(err_msg) from e
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Replaces all ShardedObject from a given state dict with values loaded from the
checkpoint.
Args:
sharded_objects_state_dict (ShardedStateDict):
sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
None: sharded state dict is modified in place
"""
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
if obj_subdir.exists():
obj_files = [f.name for f in obj_subdir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}'
)
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint'
f' directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict)
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
sharded_metadata = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir():
continue
shard_files = list(subdir.glob('shard_*.pt'))
if not shard_files:
continue
sh_objs = []
for shard_file in shard_files:
full_key = f'{subdir.name}/{shard_file.stem}'
sh_objs.append(ShardedObject.empty_from_unique_key(full_key))
# This is a backward-compatibility fix, where the last global shape is missing in the
# name
if sh_objs[0].global_shape[-1] < 0:
max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs))
for sh_obj in sh_objs:
sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1)
# Update the sharded state dict
for sh_obj in sh_objs:
sharded_metadata[sh_obj.unique_key] = sh_obj
return sharded_metadata
@property
def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True
def check_backend_compatibility(self, loaded_version):
pass
def check_version_compatibility(self, loaded_version):
pass
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import gc
import logging
import os
import queue
from contextlib import contextmanager
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
_results_queue = None
def _get_write_results_queue():
global _results_queue
if _results_queue is None:
ctx = mp.get_context('spawn')
_results_queue = ctx.Manager().Queue()
return _results_queue
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
This class doesn't spawn the async process itself, relies on the external async mechanism.
Flow:
1. Call `write_data`
2. Externally start async process with `get_save_function_and_args` function and args
3. The async function to call is `writer_proxy_func` which calls
`write_preloaded_data` in multiple processes
After saving is finalized on all ranks:
4. Call `super().finish` with the results gathered in `self.writer_result`
Note that step (3) above can also be called synchronously.
Currently, it's assumed that a separate writer is created for each ckpt save
(intermediate state is stored as writer attributes).
"""
def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs):
super().__init__(*args, **kwargs)
if not self.single_file_per_rank:
raise NotImplementedError(
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
self.separation_hint = separation_hint
def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Args:
plan (SavePlan): save plan generated by the PyT Distributed compatible planner
planner (SavePlanner): save planner used to resolve the bytes and tensor data
Returns: None, but stores the save plan in `self.write_buckets`
"""
storage_plan: _StoragePrefix = plan.storage_data
start = time()
logger.debug(f"thread_count: {self.thread_count}, time: {start}")
if self.separation_hint:
assert (
self.thread_count > 1
), "thread_count must be at least 2 if separation_hint is provided"
bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
item_buckets = _split_by_size_and_type(bins, plan.items, self.separation_hint)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
def gen_file(prefix=""):
nonlocal file_count
file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
return file_name
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for group_name, group_buckets in _split_by_separation_hint(
item_buckets, self.separation_hint
).items():
for bucket in group_buckets:
bytes_data = [
(item, planner.resolve_data(item))
for item in bucket
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
if len(bytes_data) > 0 or len(tensor_data) > 0:
file_name = gen_file(prefix=group_name)
self.write_buckets.append(
(self.path / file_name, file_name, (bytes_data, tensor_data))
)
# Check if there is anything to write on this rank
if len(self.write_buckets) > 0:
assert len(self.write_buckets) <= self.thread_count, (
len(self.write_buckets),
self.thread_count,
)
self.results_queue = _get_write_results_queue()
else:
self.results_queue = None
end = time()
logger.debug(f"D2H and push, time: {end - start}")
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: None (if there is nothing to write on this rank) or a tuple of:
- the function that saves the data
- arguments to that function
"""
if not self.write_buckets:
return None, ()
return (self.write_preloaded_data_multiproc, (self.write_buckets, self.results_queue))
@staticmethod
@_disable_gc()
def write_preloaded_data_multiproc(
write_buckets: List[WriteBucket], global_results_queue: mp.Queue
) -> None:
"""
Performs saving data to storage with multiple processes.
Starts predefined number of processes and uses 2 queues to make sure the results
are complete:
- local_results_queue - to send the actual results
- count_queue - small queue to mark worker as completed
Using just one queue disallowed proper exception handling.
This method is meant to be run in a forked subprocess.
Triggering GC during execution leads to CUDA errors
(cleaning up tensors owned by the parent process).
To prevent this, we disable the GC explicitly for this function with _disable_gc.
Args:
write_buckets (List[WriteBucket]): write plan
global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
(or an Exception) from parallel write processes to the main training process
Returns: None
"""
w_start = time()
write_results_or_exc: Union[dict, Exception] = dict()
ctx = mp.get_context('fork')
local_results_queue = ctx.Queue()
count_queue = ctx.JoinableQueue()
p_list = []
for i, write_bucket in enumerate(write_buckets):
try:
count_queue.put(i)
p_list.append(
ctx.Process(
target=FileSystemWriterAsync.write_preloaded_data,
args=(i, write_bucket, local_results_queue, count_queue, True),
)
)
except Exception as e:
err_msg = f'An error is caught while a proc {i} is created, error: {e}'
logger.error(err_msg)
write_results_or_exc = RuntimeError(err_msg)
if not isinstance(write_results_or_exc, Exception):
for p in p_list:
p.start()
logger.debug('FileSystemWriterAsync: collecting worker results...')
# To make sure all nodes are completed
count_queue.join()
# At this point, all workers completed, so the queue should have exactly
# `len(write_buckets)` items
for proc_idx in range(len(write_buckets)):
try:
local_proc_idx, local_results_or_exc = local_results_queue.get()
except queue.Empty:
write_results_or_exc = RuntimeError(
f'Unexpected empty `local_results_queue`'
f' (got only {proc_idx}/{len(write_buckets)} items)'
)
break
else:
if isinstance(local_results_or_exc, Exception):
err_msg = (
f"Local process {local_proc_idx} encountered"
f" an error: {local_results_or_exc}"
)
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
else:
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
logger.debug('FileSystemWriterAsync: collected worker results successfully')
global_results_queue.put(write_results_or_exc)
w_end = time()
logger.debug(
f"{w_end}, rank: {torch.distributed.get_rank()},"
f" write(sync,parallel): {w_end - w_start}"
)
@staticmethod
@_disable_gc()
def write_preloaded_data(
local_proc_idx: int,
write_bucket: WriteBucket,
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
use_fsync: bool,
) -> None:
"""
Performs actual data saving to storage.
Args:
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.Queue): queue to return the write results
to the proxy checkpoint process.
count_queue (mp.JoinableQueue): queue to marks worker task as completed
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write result are put into the `queue`
"""
mem_before = _process_memory()
local_results = []
try:
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
with open(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(_write_item(stream, data, write_item, storage_key))
for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(_write_item(stream, tensor, write_item, storage_key))
if use_fsync:
os.fsync(stream.fileno())
local_output = (local_proc_idx, local_results)
except Exception as e:
local_output = (local_proc_idx, e)
results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before},"
f" before: {mem_before}, after: {mem_after}"
)
def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
"""Write all items from ``plan``."""
raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
def retrieve_write_results(self) -> List[WriteResult]:
"""
Turn the latest dict including write results from `self.results_queue`
into a single results lists. Includes error check.
Returns (List[WriteResult]): the list of write results
from all local processes performing the save.
"""
assert self.write_buckets is not None
if self.results_queue is None:
write_results_or_exc = {}
else:
try:
write_results_or_exc = self.results_queue.get_nowait()
except queue.Empty:
raise RuntimeError(f'results_queue should not be empty')
if isinstance(write_results_or_exc, Exception):
raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc
write_results: dict = write_results_or_exc
if len(write_results) != len(self.write_buckets):
raise RuntimeError(
f'Incomplete worker results (expected {len(self.write_buckets)},'
f' got {len(write_results)}. This probably indicates a worker failure.'
)
return list(chain.from_iterable(write_results.values()))
def _split_by_size_and_type(
bins: int, items: List[WriteItem], separation_hint: Optional[str] = None
) -> List[List[WriteItem]]:
"""
Splits write items according to item size into close to uniform bins.
Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
but with a fixed _item_size function.
Args:
bins (int): numbers of bins to split to
items (List[WriteItem]): list of write items
Returns (List[List[WriteItem]]): write items split to bins
"""
if bins == 1:
return [items]
bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_items.sort(key=_item_size, reverse=True)
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Then, assign tensors according to their sizes
for item in tensor_items:
# TODO replace with headq
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
return buckets
def _split_by_separation_hint(
buckets: List[List[WriteItem]], separation_hint: Optional[str] = None
) -> Dict[str, List[List[WriteItem]]]:
"""
Splits buckets into those whose keys begin with the separation_hint and those whose keys do not
Args:
buckets (List[List[WriteItem]]): buckets to split
separation_hint (Optional[str]): optional prefix to split on
Returns (Dict[str, List[List[WriteItem]]]): a dictionary
mapping the prefix to the relevant buckets
"""
bins = len(buckets)
buckets_with_separation_hint = {}
if separation_hint is not None:
buckets_default = [[] for _ in range(bins)]
buckets_hint = [[] for _ in range(bins)]
for i in range(bins):
for item in buckets[i]:
if item.index.fqn.startswith(separation_hint):
buckets_hint[i].append(item)
else:
buckets_default[i].append(item)
buckets_with_separation_hint[""] = buckets_default
buckets_with_separation_hint[separation_hint] = buckets_hint
else:
buckets_with_separation_hint[""] = buckets
return buckets_with_separation_hint
def _item_size(item: WriteItem) -> int:
"""
Calculates size (in bytes) of a single write item.
Same as torch.distributed.checkpoint.filesystem._item_size,
but fixes computing chunk size (with item.tensor_data.chunk.sizes)
Args:
item (WriteItem): write item to compute the size of
Returns (int): size of an item in bytes
"""
size = 1
assert item.tensor_data is not None
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.chunk.sizes:
size *= s
dtype = item.tensor_data.properties.dtype
return size * torch._utils._element_size(dtype)
def _process_memory() -> int:
"""
Get memory used by current process.
Returns (int): memory used by current process
"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
return mem_info.rss
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from pathlib import Path
from time import time
from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
extract_matching_values,
merge,
nested_values,
)
from megatron.core.dist_checkpointing.exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.strategies.base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
from megatron.core.dist_checkpointing.validation import (
determine_global_metadata,
validate_sharding_integrity,
)
logger = logging.getLogger(__name__)
class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
"""Wraps arbitrary strategy and distributes the save during `save`.
The save distribution happens without any *data* communication.
Only the *metadata* is exchanged and based on data replication on different
ranks, we try to distribute the save as uniformly as possible.
This wrapper assumes, that setting `replica_id` to 0 will make the
underlying strategy do the saving on current rank. All the other `replica_id`s
are set to 1.
Currently, the save distribution is realized with a greedy algorithm
described in `distribute_shards_to_ranks`.
Args:
strategy (SaveShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for save
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
do_cache_distribution (bool, optional): whether to cache the save distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to True.
"""
def __init__(
self,
strategy: SaveShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
):
super().__init__(strategy.backend, strategy.version)
self.base_strategy = strategy
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.cached_distribution: Optional[ShardDistribution] = None
def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
)
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.save(sharded_state_dict, checkpoint_dir)
def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None:
"""Distributes the save across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of saves among the ranks.
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the saving
Returns: None
"""
start = time()
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* save parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply save parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.cached_distribution is None:
# First time applying the parallelization
validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1])
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
end = time()
logger.debug(f"parallel save sharding, time: {end - start}")
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
"""Wraps arbitrary load strategy and distributes the load during `load`.
See `load` method docs for details.
Args:
strategy (LoadShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for load
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
In most cases, it's recommended to set it to the DP group.
do_cache_distribution (bool, optional): whether to cache the load distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to False,
since the loading in general happens only once during training.
Note that the load distribution *cannot* be reused as a save distribution,
because save/load is not fully symmetrical.
exchange_algo (str): algorithm to use for exchanging the data.
Options:
- broadcast - each rank broadcasts individual tensors to others
- gather_object (default) - ranks all_gather_object the whole loaded state dicts
- gather_rounds (default) - ranks all gather individual tensors in rounds
See method docs for more details.
"""
def __init__(
self,
strategy: LoadShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
exchange_algo: str = 'broadcast',
):
super().__init__()
self.base_strategy = strategy
if parallelization_group is None:
parallelization_group = (
dist.GroupMember.WORLD
) # explicit group needed for torch.distributed.get_global_rank call
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.exchange_algo = exchange_algo
self.cached_distribution: Optional[ShardDistribution] = None
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Distributes the load and calls underlying strategy only for parts of the state dict.
Steps:
1. Load metadata is exchanged between the ranks in the parallelization group.
2. Each rank deterministically plans the load for the whole workload
so that the loads are as uniform as possible.
3. Each ranks loads its planned shard of the checkpoint.
4. All ranks exchange the loaded shards.
Internode communication is involved in steps (1) (with metadata)
and (4) (with actual data). Storage interaction is involved in step (3).
Currently, the load distribution (step 2) is realized with a greedy algorithm
described in `distribute_shards_to_ranks` (same as for saving distribution).
Currently, the shards are all gathered between all ranks in the parallelization
group. This might not be optimal (some ranks do not need all tensors),
but it's a reasonable approximation for an optimal exchange in most scenarios.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory to load from
Returns:
StateDict: loaded state dict. The state dict should be equivalent to
a state dict that would be loaded with the underlying strategy
without this wrapper.
"""
if torch.distributed.get_world_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
# Step 1 and 2: exchange load metadata and distribute the load
start = time()
precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
end = time()
logger.debug(f'self.apply_loading_parallelization took {end - start}s')
start = end
# Step 3: load part of the checkpoint.
# Load only sharded objects first. ShardedTensors will be loaded separately
# so that we can keep track of sharded tensors loaded by this rank
(sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
self._defer_loading_sharded_tensors(sharded_state_dict)
)
loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedObjects took {end - start}s')
start = end
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedTensors took {end - start}s')
start = end
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
)
sync_start = time()
torch.cuda.synchronize()
end = time()
logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
logger.debug(f'self.exchange_loaded_tensors took {end - start}s')
self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
merge(loaded_state_dict, sharded_tensors)
return loaded_state_dict
def _defer_loading_sharded_tensors(
self, sharded_state_dict: ShardedStateDict
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedTensor],
Dict[_ShardId, ShardedTensor],
]:
"""Divides state dict into parts loaded by this vs other ranks.
ShardedTensors with main replica_id will be loaded by this rank,
others will be received by other ranks (after loading from storage).
Args:
sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
that will be divided.
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with ShardedTensors
- ShardedStateDict: sub-state dict with non-ShardedTensors
- Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified
by shard ids. This is a mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *this* rank
- Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_tensors, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
)
def wrap_non_main_replicas(x):
if isinstance(x, ShardedTensor):
# Assign shard to be loaded or not
if is_main_replica(x.replica_id):
to_load_shards[_sharded_tensor_shard_id(x)] = x
else:
unloaded_shards[_sharded_tensor_shard_id(x)] = x
return x
dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
) -> Optional[ShardDistribution]:
"""Distributes the load across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of loads among the ranks.
Marks ShardedTensors to be loaded by the current rank with replica_id 0
(and others with non 0 values).
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the loading
Returns:
ShardDistribution (optional): the computed loading distribution
"""
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* load parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply load parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group, True
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
return precomputed_distribution
def fill_in_deferred_sharded_tensors(
self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
Returns:
"""
def fill_in_sharded_tensor(x):
if isinstance(x, ShardedTensor):
try:
x = loaded_tensors[_sharded_tensor_shard_id(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
def load_tensors_metadata(self, checkpoint_dir: Path):
return self.base_strategy.load_tensors_metadata(checkpoint_dir)
def load_sharded_metadata(self, checkpoint_dir: Path):
return self.base_strategy.load_sharded_metadata(checkpoint_dir)
def check_backend_compatibility(self, loaded_version):
return self.base_strategy.check_backend_compatibility(loaded_version)
def check_version_compatibility(self, loaded_version):
return self.base_strategy.check_version_compatibility(loaded_version)
def distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
precomputed_distribution: Optional[ShardDistribution],
):
"""Applies the save distribution computed with `determine_main_replica_uniform_distribution`.
Based on rank assignment, sets replica ids of the shards saved by current rank to 0
and all the other replica ids to 1.
Args:
sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to
parallelization_group (ProcessGroup): distribution will be applied within this
process group. Must match with the process group passed to
`determine_main_replica_uniform_distribution`.
precomputed_distribution (ShardDistribution): distribution computed with
`determine_main_replica_uniform_distribution`
Returns: None
Example replica ids of tensors A, B, C before distribution:
rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)
Replicas after distribution for the example above:
rank0: A: 0, B: 1, C: 1
rank1: A: 1, B: 0, C: 1
rank2: A: 1, B: 1, C: 0
"""
if torch.distributed.get_world_size(group=parallelization_group) <= 1:
return
if precomputed_distribution is None:
raise ValueError(
'precomputed_distribution must be not None for non-trivial parallelization group'
)
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
rank_within_dp_group = torch.distributed.get_rank(parallelization_group)
for sh_ten in local_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
if (
shard_id in precomputed_distribution.shards_in_this_group
and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id]
):
sh_ten.replica_id = 0
else:
sh_ten.replica_id = 1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
""" Performant resharding of flattened tensors.
Tensors that are first sharded (e.g. across TP) and then flattened cause
very irregular access patterns during loading. The idea for performant save/load
is to store tensors with global shape [X, Y, Z] and local shape [x, y, z]
as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and
local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the
last (flattened) dimension. During loading, some additional resharding is needed.
"""
import logging
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
from torch.distributed.checkpoint import ChunkStorageMetadata
from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
logger = logging.getLogger(__name__)
@dataclass
class TensorReformulationMetadata:
"""Metadata needed to restore the original tensor shape.
Args:
ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor
saved in the checkpoint. This is the global shape of the application,
further reformulated into `ckpt_reform_global_shape` while saving.
ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor
saved in the checkpoint. This is the actual saved shape.
"""
ckpt_orig_global_shape: Tuple[int, ...]
ckpt_reform_global_shape: Tuple[int, ...]
def __post_init__(self):
assert self.ckpt_orig_global_shape
def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]:
"""Reformulated global shape of the flattened N-D ShardedTensor.
N-D tensor global shape [X, Y, Z] and local shape [x, y, z]
is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and
local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the
last (flattened) dimension.
Args:
sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1)
Returns:
Tuple[int, ...]: reformulated tensor shape
"""
assert is_nd_flattened_tensor(sh_ten), sh_ten
return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),)
def is_nd_flattened_tensor(sh_ten: Any) -> bool:
"""Checks if ShardedTensor is flattened and more than 1-dimensional
Args:
sh_ten (Any): any object
Returns:
bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
"""
return (
isinstance(sh_ten, ShardedTensor)
and sh_ten.flattened_range is not None
and len(sh_ten.global_shape) > 1
)
# information needed to restore. With current implementation, this is a nested state dict
# with ShardedTensorFactories which is basically a ShardedStateDict type
ReformulationRestoreMetadata = ShardedStateDict
def apply_nd_flattened_tensors_reformulation(
sharded_state_dict: ShardedStateDict,
reformulation_metadata: Dict[str, TensorReformulationMetadata],
) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]:
"""Applies N-D reformulation to a given sharded state dict.
After applying the method and loading the reformulated state dict,
the `restore_nd_flattened_tensors_formulation` needs to be applied.
Current implementation uses ShardedTensorFactories for convenience of
restoring the original structure, but it's just an implementation detail.
Turns N-D ShardedTensors into factories and immediately applies them,
keeping the data needed to restore the original structure.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict potentially
with tensors to reformulate.
reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict
containing all metadata needed for reformulating tensors in `sharded_state_dict`.
for each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an
entry with `sh_ten.key`.
Returns:
tuple:
ShardedStateDict - reformulated sharded state dict
ReformulationRestoreMetadata - data needed to restore the original formulation
with `restore_nd_flattened_tensors_formulation`
"""
def maybe_reformulate_nd_flattened_tensor(sh_ten: Any):
if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten):
return sh_ten
# N-D flattened ShardedTensor
try:
sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
except KeyError as e:
raise CheckpointingException(
f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}'
) from e
ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if ckpt_actual_saved_shape == app_actual_load_shape:
# Same shape - no need to reshard
return sh_ten
return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata)
# Turn N-D tensors into factories and immediately apply them
dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Unlink `data` pointers to free memory
def unlink_data(x):
x.data = None
return x
dict_list_map_inplace(unlink_data, sh_ten_factories)
return sharded_state_dict, sh_ten_factories
def restore_nd_flattened_tensors_formulation(
state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata
) -> StateDict:
"""Restores the original state dict from a reformulated form.
Inverse of `apply_nd_flattened_tensors_reformulation`.
Args:
state_dict (StateDict): state dict obtained by loading a reformulated
sharded state dict.
formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by
`apply_nd_flattened_tensors_reformulation` function
Returns:
StateDict: state dict with the original tensors formulation restored
"""
return apply_factory_merges(state_dict, formulation_restore_metadata)
def reformulate_single_nd_flattened_tensor(
sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata
) -> Union[Any, ShardedTensorFactory]:
"""Reformulates shapes of a single N-D flattened ShardedTensor.
We need to define a pair of transformations:
- turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors
- merge multiple reformulated loaded torch.Tensors into a single original tensor
Current implementation uses ShardedTensorFactories as a convenient mechanism
for specifying and keeping track of those transformations.
Args:
sh_ten (ShardedTensor): sharded tensor to reformulate.
reformulation_metadata (TensorReformulationMetadata): metadata needed to
perform the reformulation
Returns:
ShardedTensorFactory: factory that keeps information how to reformulate
(build) the ShardedTensor and then restore original formulation (merge)
after loading.
"""
rmd = reformulation_metadata
# Data won't be needed - remove unnecessary tensor references
sh_ten = sh_ten.without_data()
# Based on reformulation_metadata, determine other tensor shapes and metadata
ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1]
for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation):
assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape)
ckpt_local_shape_with_prepended_axis = tuple(
sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation)
)
assert (
ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num]
== (1,) * sh_ten.prepend_axis_num
), (ckpt_local_shape_with_prepended_axis, sh_ten)
ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :]
# Iterate over reformulated shapes needed by the application and from checkpoint,
# and generate new ShardedTensors that match the checkpoint sharding.
overlap_dim_offsets = []
assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), (
ckpt_axis_fragmentation,
sh_ten,
)
for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate(
zip(
sh_ten.local_chunk_offset_in_global(),
ckpt_axis_fragmentation,
sh_ten.axis_fragmentations,
)
):
# without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
# `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units
next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))
logger.debug(
f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}'
f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}'
)
reformulated_sh_tens = {}
for chunk_offset in product(*overlap_dim_offsets):
global_offset = tuple(
chunk_off * chunk_shape
for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis)
)
reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor(
sh_ten.key,
None,
sh_ten.dtype,
ckpt_local_shape,
rmd.ckpt_orig_global_shape,
global_offset,
ckpt_axis_fragmentation,
sh_ten.replica_id,
sh_ten.prepend_axis_num,
sh_ten.allow_shape_mismatch,
flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]), # whole ckpt shard
)
# Now, we have to define the transformations from application sharding
# to checkpoint sharding.
@torch.no_grad()
def sh_ten_build_fn(*args, **kwargs):
# Here we simply return the precomputed tensors.
return reformulated_sh_tens
@torch.no_grad()
def sh_ten_merge_fn(sub_state_dict):
# This is the non-flattened local tensor with original formulation
# that we are going to fill with shards loaded from the checkpoint.
app_non_flat_ten = torch.empty(
sh_ten.local_shape,
dtype=sh_ten.dtype,
device=sh_ten.data.device if sh_ten.data is not None else None,
)
assert len(sub_state_dict) > 0
for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items():
# For each ckpt shard, we fill the appropriate application shard part
dest_ten = app_non_flat_ten
src_ten = ckpt_ten.view(ckpt_local_shape)
# We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice
for (
dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=ChunkStorageMetadata(
ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape
),
current_shard=ChunkStorageMetadata(
sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape
),
):
src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length)
dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length)
dest_ten.copy_(src_ten)
return app_non_flat_ten.flatten()[sh_ten.flattened_range]
return ShardedTensorFactory(
sh_ten.key,
sh_ten.data,
sh_ten_build_fn,
sh_ten_merge_fn,
sh_ten.replica_id,
sh_ten.flattened_range,
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" State dict saver for PyT Distributed format allowing asynchronous save. """
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
logger = getLogger(__name__)
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]:
"""
First stage of saving a state dict to storage.
This is an async adjustment of torch.distributed.checkpoint.state_dict_saver.
In order to support async save, saving should be split into three parts:
1. Planning
2. Actual saving
3. Finalization
Out of these, step (2) *must* happen asynchronously.
The first step is realized with this function.
The planning part consists of several steps, described here:
https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner
Args:
state_dict (STATE_DICT_TYPE): state dict to save
storage_writer (FileSystemWriterAsync): in current version only an instance of
FileSystemWriterAsync
process_group (dist.ProcessGroup, optional): process group used for save planning
coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], Optional):
Each object of this tuple will be used in the order as following
cached_central_plan (SavePlan): a globally coordinated save plan
cached in the previous iteration
cached_local_plan (SavePlan): a local plan
cached in the previous iteration
validated_cache_reuse (bool): boolean value to tell global_metadata and planning dict
is consistent over iterations
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning.
"""
cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
if cached_ckpt_structure:
cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan
def local_step():
nonlocal local_plan
assert planner is not None
# PyTorch 2.4 introduced additional `metadata` argument,
# we have to reference `is_coordinator` args by name
planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator)
storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
if not validated_cache_reuse and local_plan is None:
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
# Execute local and global planning
start_plan = time()
if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
end_plan = time()
logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}")
# Prepare async writing of tensors.
# The `storage_writer` will store the information about tensors it needs to save
start = time()
storage_writer.prepare_write_data(central_plan, planner)
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return (
(storage_writer, cast(Metadata, global_metadata), dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
)
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None:
"""
Finalization of save_state_dict_async_plan.
The input arguments are the same as the save_state_dict_async_plan output,
the `write_results` are retrieved from the storage_writer.
Args:
storage_writer (FileSystemWriterAsync): storage writer used for planning
global_metadata (Metadata): metadata created during planning
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: None
"""
write_results = storage_writer.retrieve_write_results()
# Gather the write results that will be saved to the metadata file.
gather_start = time()
all_results = dist_wrapper.gather_object(write_results)
gather_end = time()
logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}")
# Store the metadata on coordinator rank
if dist_wrapper.is_coordinator:
node_failures = _get_failure_dict(all_results)
if len(node_failures) == 0:
assert global_metadata is not None
write_start = time()
storage_writer.finish(global_metadata, all_results)
write_end = time()
logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
else:
raise CheckpointException("write", node_failures)
......@@ -4,6 +4,7 @@
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path
import tensorstore as ts
......@@ -12,13 +13,22 @@ import torch
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import postprocess_numpy_array
from .base import LoadShardedStrategy, StrategyAction, register_default_strategy
from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array
_import_trigger = None
logger = getLogger(__name__)
def register_default_tensorstore_strategies():
"""Register default strategies leveraging tensorstore."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy()
)
class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
"""Load strategy for Zarr backend using `tensorstore` for loading."""
def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
......@@ -36,6 +46,13 @@ class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
......@@ -44,6 +61,8 @@ class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
def merge_global_slice_with_shape(global_slice, actual_shape, key):
"""Intersects the global slice with the actual shape (prevent overflow)."""
def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
......@@ -74,18 +93,7 @@ def _load_from_array(
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'driver': 'file',
'path': str(checkpoint_dir / sharded_tensor.key),
}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} could not be loaded. Error: {e}'
) from e
arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
......@@ -105,6 +113,16 @@ def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
return x
default_strategies[StrategyAction.LOAD_SHARDED.value][
('zarr', 1)
] = TensorStoreLoadShardedStrategy()
def open_ts_array(arr_path: Path):
"""Opens a Zarr file array with Tensorstore with basic setting.
Args:
arr_path (Path): path to a Zarr (Tensorstore) array
"""
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {'driver': 'file', 'path': str(arr_path)}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
return arr
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
import pickle
import warnings
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import torch
from packaging.version import Version as PkgVersion
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed.checkpoint import (
BytesStorageMetadata,
DefaultLoadPlanner,
DefaultSavePlanner,
FileSystemReader,
FileSystemWriter,
LoadPlan,
Metadata,
ReadItem,
SavePlan,
TensorStorageMetadata,
WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items
from ...utils import get_torch_version, is_torch_min_version
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
StateDict,
is_main_replica,
)
from .async_utils import AsyncRequest
from .base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
StrategyAction,
register_default_strategy,
)
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
TensorReformulationMetadata,
apply_nd_flattened_tensors_reformulation,
is_nd_flattened_tensor,
nd_flattened_tensor_reformulated_global_shape,
restore_nd_flattened_tensors_formulation,
)
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan
try:
if not torch.cuda.is_available():
raise ImportError
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
from torch.distributed._tensor import DTensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
_metadata_fn: str = ".metadata"
def register_default_torch_strategies():
"""Register default strategies related to PyT Distributed backend."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy()
)
register_default_strategy(
StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1)
)
logger = getLogger(__name__)
def flatten_state_dict(
state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
"""Flattens state dict into a single level dict.
It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
which also accepts ShardedBase tensors as terminal objects
Args:
state_dict (ShardedStateDict): state dict to be flattened
Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
"""
flattened = {}
mappings = {}
def flat_copy(path: OBJ_PATH, value: Any) -> None:
new_fqn = ".".join(map(str, path))
if new_fqn in flattened:
raise ValueError(f"duplicated flatten key {new_fqn}")
flattened[new_fqn] = value
mappings[new_fqn] = path
traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
return flattened, mappings
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
On high-level, this function follows the logic of
torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore)
as attributes for further restoration in `_unwrap_pyt_sharded_tensor`.
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
This function handles 3 different type of ShardedTensors:
1. Non-flat regular ShardedTensors (`not has_flattened_range`)
2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
3. N-D flattened ShardedTensors (`has_flattened_range`)
(1) and (2) type are saved according to their original shape.
Type (3) however requires global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
This will need special handling while resharding.
Args:
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
"""
if rank is None:
rank = torch.distributed.get_rank()
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1
for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()
local_global_offsets = {}
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if is_flattened_range_1d:
# Type (2) case: 1D flattened ShardedTensors
for sh_ten in sh_tens:
assert len(sh_ten.global_offset) == 1, sh_ten
assert sh_ten.prepend_axis_num == 0, sh_ten
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
global_shape = some_sh_ten.global_shape
offsets_shape = (
some_sh_ten.local_shape
) # local shape is not flattened, we need it for chunk offsets
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
[
sh_ten.global_offset[0] + sh_ten.flattened_range.start
], # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
elif has_flattened_range:
# Type (3) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
sh_ten
)
assert sh_ten.data.ndim == 1, sh_ten
sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,))
# Global shape reformulation:
global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten)
offsets_shape = (1,) * len(
some_sh_ten.global_shape
) # reformulated global shape has shape equal ti number of local chunks
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
list(
sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,)
), # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
else:
# Type (1) case: non-flat regular ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
sh_ten.data = sh_ten.data.view(
(1,) * prepend_axis_num + sh_ten.local_shape
) # adjust to prepended_axis_num
global_shape = some_sh_ten.global_shape
offsets_shape = some_sh_ten.data.shape # includes prepended axes
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data, list(sh_ten.global_offset), rank # simple case
)
for sh_ten in sh_tens
]
# Create a ShardedTensor without invoking communication. Determine global shards
world_size = torch.distributed.get_world_size()
shard_metadata = []
# NOTE: here we assume a regular grid of shards
for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)):
offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
if offset in local_global_offsets:
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if is_flattened_range_1d:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
elif has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
offset = sh_ten.local_chunk_offset_in_global() + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = sh_ten.data.shape
shard_metadata.append(ShardMetadata(offset, size, placement))
else:
# pylint: disable=line-too-long
# for shards from other ranks we provide simplistic data - this information will be discarded
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
# Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
# The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
placement = f"rank:{(rank + 1) % world_size}/cuda"
if has_flattened_range and not is_flattened_range_1d:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = offsets_shape
shard_metadata.append(ShardMetadata(offset, size, placement))
tensor = some_sh_ten.data
sharded_tensor_metadata = ShardedTensorMetadata(
shards_metadata=shard_metadata,
size=torch.Size(global_shape),
tensor_properties=TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
),
)
pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
)
# Store MCore related data as PyTShardedTensor attribute.
# This won't be stored in the checkpoint, only for runtime purposes
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range and not is_flattened_range_1d:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten
def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Convert state dict with ShardedTensors and ShardedObjects
to state dict compatible with PyT Dist format.
Operates in-place and returns the original state dict.
Args:
state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
are lists of either ShardedTensor or ShardedObjects.
is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
init_device (torch.device, optional): device to initialize potentially missing tensors
during loading. Defaults to 'cpu'.
Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values
converted either into PyT ShardedTensors or io.BytesIO.
"""
rank = torch.distributed.get_rank()
pyt_state_dict = {}
def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
"""Build a PyT ShardedTensor from given shards.
During loading:
- if data is None, initialize it with an empty tensor (will be used to copy the data into)
- if `allow_shape_mismatch` is True, the data is initialized with zeros
prior to loading (not all parts of the tensor will be read from the checkpoint)
"""
assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
for sh_ten in sh_tens:
if sh_ten.data is None:
if is_loading:
sh_ten.init_data(
init_device,
init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
)
else:
raise CheckpointingException(f'`data` attr is None for {sh_ten}')
else:
sh_ten.data = sh_ten.data.detach()
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
"""Build io.BytesIO from given sharded objects data."""
assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
serialized_data = io.BytesIO()
torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
return serialized_data
for k, v in state_dict.items():
if isinstance(v[0], ShardedTensor):
v = cast(List[ShardedTensor], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
else:
v = cast(List[ShardedObject], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
return pyt_state_dict
def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
"""Unwrap tensor from PyT ShardedTensor instance.
If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
then the tensor has additional singleton dimensions which should be squeezed.
"""
mcore_sh_ten = sh_ten.mcore_sh_ten
ret_tensors = []
for sh in sh_ten.local_shards():
ten = sh.tensor
if mcore_sh_ten.flattened_range is not None:
assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape
ten = ten.view(-1)
else:
for _ in range(mcore_sh_ten.prepend_axis_num):
ten = ten.squeeze(0)
ret_tensors.append(ten)
return ret_tensors
def _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
"""Group ShardedBase objects by keys and
return mappings required for recreating the original dict."""
flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)
rename_mapping = defaultdict(list)
new_flat_sd = defaultdict(list)
for k, sh_base in flat_sd.items():
assert isinstance(sh_base, ShardedBase), type(sh_base)
key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
rename_mapping[key].append(k)
new_flat_sd[key].append(sh_base)
return new_flat_sd, flat_mapping, rename_mapping
def _replace_sharded_keys_with_state_dict_keys(
state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
flat_mapping: FLATTEN_MAPPING,
rename_mapping: Dict[str, List[str]],
):
"""Inverse of _replace_state_dict_keys_with_sharded_keys."""
recovered_sd = {}
for k, tensors in state_dict.items():
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten
return unflatten_state_dict(recovered_sd, flat_mapping)
def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
"""Recursively update `x` keys, based on `keys_template`."""
if isinstance(keys_template, dict):
assert isinstance(x, dict), type(x)
for k, v in keys_template.items():
if not isinstance(k, str):
assert str(k) in x, (k, x.keys)
x[k] = x.pop(str(k))
_restore_dict_types(x[k], v)
elif isinstance(keys_template, list):
assert isinstance(x, list), type(x)
for x_val, templ_val in zip(x, keys_template):
_restore_dict_types(x_val, templ_val)
@dataclass(frozen=True)
class MCoreSavePlan(SavePlan):
"""SavePlan with MCore specific data."""
mcore_data: Dict[str, Dict[str, Any]] = None # Mcore related data about each tensor
class MCoreSavePlanner(DefaultSavePlanner):
"""Differs with the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects
come from ShardedObjects, which should be treated as separate objects on each rank
(not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it
in transform_object.
"""
def __init__(
self,
*args,
dedup_replicated_tensors: Optional[bool] = None,
nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
**kwargs,
) -> None:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
# during saving.
if get_torch_version() <= PkgVersion("2.2"):
kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
def create_local_plan(self) -> SavePlan:
"""Adds IOBytes write request on non-coordinator ranks."""
# NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because
# some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container)
# add iobytes request only on coordinator ranks and some alpha versions
# (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container)
# add those requests on all ranks. We inline a simplified version of this method below.
write_items = []
for fqn, obj in self.state_dict.items():
assert not HAVE_DTENSOR or not isinstance(
obj, DTensor
) # translation from MCore ShardedTensors shouldn't result in DTensors
# Create write requests for tensor and bytes values.
# For MCore, these should be already non-duplicates.
write_items += _create_write_items(fqn, obj)
self.plan = MCoreSavePlan(
items=write_items,
planner_data=self.mappings,
mcore_data={
k: sh_ten.mcore_metadata
for k, sh_ten in self.state_dict.items()
if isinstance(sh_ten, TorchShardedTensor)
},
)
return self.plan
def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]:
"""Merges MCore data for all plans."""
global_plan, metadata = super().create_global_plan(all_plans)
metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
return global_plan, metadata
def transform_object(self, write_item: WriteItem, object: Any):
"""Make no transformations - bytes objects are already serialized."""
return object
class MCoreLoadPlanner(DefaultLoadPlanner):
"""Adds global shape validation to the default planner.
If global shape validation can be ignored (shouldn't!), the default
load planner can be used.
"""
def __init__(
self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors
self._intermediate_read_item_and_target: Optional[Tuple[ReadItem, torch.Tensor]] = None
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
raise KeyError(
f"{sh_ten.key} from model not in state dict:"
f" {sorted(metadata.state_dict_metadata.keys())}"
)
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
if not is_nd_flattened_tensor(sh_ten):
expected_shape = sh_ten.global_shape
else:
expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if loaded_shape != expected_shape:
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
f' for key {sh_ten.key}'
)
raise CheckpointingException(_msg)
def create_local_plan(self) -> LoadPlan:
"""Runs additional shapes validation."""
self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
return super().create_local_plan()
def resolve_tensor(self, read_item: ReadItem):
"""Override to add FP8 support.
Narrowing the Float8Tensor can create incontiguous tensors and there are
no `copy` kernels for such cases. This method creates a contiguous FP8
tensors so that the subsequent `copy_` in FileSystemReader succeeds.
Note that this requires tracking the original tensor
(as `self._intermediate_read_item_and_target` attribute)
and restoring it in `commit_tensor` method.
"""
target_tensor = super().resolve_tensor(read_item)
if (
not target_tensor.is_contiguous()
and HAVE_TE
and isinstance(target_tensor, Float8Tensor)
):
self._intermediate_read_item_and_target = (read_item, target_tensor)
target_tensor = Float8Tensor.make_like(
target_tensor, data=target_tensor._data.contiguous()
)
return target_tensor
def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
"""Restores the original FP8 tensor saved in `resolve_tensor`."""
if self._intermediate_read_item_and_target is not None:
interm_read_item, target_tensor = self._intermediate_read_item_and_target
assert (
interm_read_item is read_item
), '`commit_tensor` method should be called right after `resolve_tensor`'
target_tensor.copy_(tensor)
tensor = target_tensor
self._intermediate_read_item_and_target = None
return super().commit_tensor(read_item, tensor)
class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
"""Async save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors
and use the async-adjusted torch.distributed.checkpoint saving mechanism
provided by the FileSystemWriterAsync writer.
"""
def __init__(
self,
backend: str,
version: int,
keep_only_main_replica: bool = True,
thread_count: int = 2,
cached_metadata: bool = False,
separation_hint: str = None,
):
"""Adds parameters specific to PyT Distributed format
Args:
backend (str): format backend string
version (int): format version
keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
for deduplication, but replica_id aware deduplication is more coherent.
Default is True (recommended to keep it).
thread_count (int, optional): threads to use during saving.
Affects the number of files in the checkpoint (saving ranks * num_threads).
cached_metadata (bool, optional): Enables using cached global metadata to avoid
gathering local metadata every checkpointing invocation
separation_hint(str, optional): If provided, all tensors whose keys have this
prefix will be saved to a separate file.
"""
super().__init__(backend, version)
self.keep_only_main_replica = keep_only_main_replica
self.thread_count = thread_count
# Cached SavePlans to skip plan in `save_state_dict_async_plan`
# cached outcome of `SavePlan.prepare_global_plan`,
# which aggregates local plans from all ranks
self.cached_central_plan: SavePlan = None
# cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written
self.cached_local_plan: SavePlan = None
# Cached global metadata, only `coordinator` for dist-ckpt holds
# if central plans are consistent over iters
self.cached_global_metadata: Metadata = None
# This variable records if the ckpt structures are consistent
# so the following checkpoint savings reuse `cached_global_metadata`
self.validated_cache_reuse: bool = False
# The knob to enable cached metadata communication in saving
self.use_cached_ckpt_structure: bool = cached_metadata
self.separation_hint = separation_hint
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
"""Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint directory
Returns: None
"""
# Translate the state dict
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(
sharded_state_dict, self.keep_only_main_replica
)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
# Use PyT saving mechanism
writer = FileSystemWriterAsync(
checkpoint_dir, separation_hint=self.separation_hint, thread_count=self.thread_count
)
# This should be set differently if we run in a smaller process group than the default
coordinator = 0
# Try twice to validate the generated `central_plan` is the same across iterations
# If so, reuse `cached_central_plan` and `cached_global_metadata`
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
if self.use_cached_ckpt_structure:
args_cached_plans = (
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
)
(
save_state_dict_ret,
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
) = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
coordinator,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
cached_ckpt_structure=args_cached_plans,
)
rank = torch.distributed.get_rank()
if self.use_cached_ckpt_structure:
if self.validated_cache_reuse:
logger.debug(f"rank: {rank}, cache validated")
if save_state_dict_ret[1]: # when global_metadata is not cached
self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata
# Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata)
elif coordinator == rank:
logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}")
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args()
save_fn, save_args = save_fn_args
def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn])
def can_handle_sharded_objects(self):
return True
def get_reformulation_metadata(
sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> Dict[str, TensorReformulationMetadata]:
"""Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory
Returns:
Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every
N-D flattened tensor from the sharded_state_dict to its original global shape
as stored in `mcore_data` in the checkpoint.
"""
ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata()
reformulation_metadata = {}
for sh_ten in nested_values(sharded_state_dict):
if not is_nd_flattened_tensor(sh_ten):
continue
try:
ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
) from e
reformulation_metadata[sh_ten.key] = TensorReformulationMetadata(
ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size
)
return reformulation_metadata
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format."""
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
# Apply N-D tensors resharding
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
)
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
]
orig_sharded_state_dict = sharded_state_dict
# MCore state dict to PyT Distributed compatible
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict = {
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
_restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
# Apply N-D tensors resharding postprocessing
mcore_state_dict = restore_nd_flattened_tensors_formulation(
mcore_state_dict, formulation_restore_data
)
return mcore_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None):
"""Uses tensors metadata stored in the metadata file."""
if metadata is None:
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
mcore_data = getattr(metadata, 'mcore_data', {})
sharded_metadata = {}
for k, tp in metadata.state_dict_metadata.items():
if not isinstance(tp, TensorStorageMetadata):
continue # load only tensors
nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
if nd_orig_global_shape is None:
# Regular tensor
sharded_metadata[k] = ShardedTensor.from_rank_offsets(
k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
).without_data()
else:
# N-D flattened tensor
unflat_ten = torch.empty(
nd_orig_global_shape, **tp.properties.__dict__, device='meta'
)
flat_ten = unflat_ten.flatten()
sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
k,
flat_ten,
unflat_ten.shape,
flattened_range=slice(0, unflat_ten.numel()), # whole slice
).without_data()
return sharded_metadata
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Uses tensors and objects metadata stored in the metadata file."""
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
sharded_metadata = {}
for metadata_key, storage_metadata in metadata.state_dict_metadata.items():
if not isinstance(storage_metadata, BytesStorageMetadata):
continue
sh_obj = ShardedObject.empty_from_unique_key(metadata_key)
sharded_metadata[sh_obj.unique_key] = sh_obj
sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata))
return sharded_metadata
def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
"""Removes checkpoint files whose keys have the given prefix.
Performs the following steps:
1. checks whether there are files that start with the key_prefix
2. loads metadata
3. removes all entries from the metadata that start with the key_prefix
4. resaves the new metadata and removes the old metadata
5. removes the relevant files
"""
assert is_torch_min_version(
"2.3.0"
), f'torch >= 2.3.0 is required for remove_sharded_tensors'
distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")]
files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)]
if not files_to_remove:
warnings.warn(
f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".'
f' Skipping removal.'
)
return
fs_reader = FileSystemReader(checkpoint_dir)
original_metadata = fs_reader.read_metadata()
new_state_dict_metadata = {}
new_planner_data = {}
new_storage_data = {}
for k in original_metadata.state_dict_metadata.keys():
if k.startswith(key_prefix):
continue
new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k]
for k in original_metadata.planner_data.keys():
if k.startswith(key_prefix):
continue
new_planner_data[k] = original_metadata.planner_data[k]
for k in original_metadata.storage_data.keys():
if k.fqn.startswith(key_prefix):
continue
new_storage_data[k] = original_metadata.storage_data[k]
metadata = Metadata(
state_dict_metadata=new_state_dict_metadata,
planner_data=new_planner_data,
storage_data=new_storage_data,
)
fs_writer = FileSystemWriter(checkpoint_dir)
metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn))
tmp_path = cast(
metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp")
)
old_path = cast(
metadata_filename, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck")
)
## save the new metadata
with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()
## move the old metadata
fs_writer.fs.rename(fs_writer.metadata_path, old_path)
try:
## rename the new metadata
fs_writer.fs.rename(tmp_path, fs_writer.metadata_path)
## finally, remove the files we want to drop
for f in files_to_remove:
fs_writer.fs.rm_file(checkpoint_dir / f)
except Exception as e:
fs_writer.fs.rename(old_path, fs_writer.metadata_path)
raise e
else:
fs_writer.fs.rm_file(old_path)
def can_handle_sharded_objects(self):
return True
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
......@@ -17,8 +17,8 @@ import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import _load_from_array
from .zarr import flatten_range
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None
......@@ -59,14 +59,11 @@ class _ShardedTensorMetadata:
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
return (
sharded_tensor.key,
sharded_tensor.global_offset,
)
return (sharded_tensor.key, sharded_tensor.global_offset)
class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
""" Loads one checkpoint replica from storage and broadcasts to other nodes.
"""Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads checkpoint from storage on minimal set of nodes
and distributes the checkpoint to other nodes with torch.distributed.
......@@ -77,19 +74,18 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
a) on one of the ranks load it from storage to CPU and move to CUDA
b) allocate CUDA tensor on other ranks
c) broadcast within DP group
d) copy tensor content to the model param location
e) free tensor buffers from a) and b)
3.a) on one of the ranks load it from storage to CPU and move to CUDA
3.b) allocate CUDA tensor on other ranks
3.c) broadcast within DP group
3.d) copy tensor content to the model param location
3.e) free tensor buffers from a) and b)
Notes:
1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
2. There is a lot of overlap potential between all three steps done for each tensor:
a) loading from storage to numpy
b) moving CPU tensors to CUDA
c) broadcast
2.a) loading from storage to numpy
2.b) moving CPU tensors to CUDA
2.c) broadcast
"""
def __init__(self, data_parallel_group, cpu_transfer=True):
......@@ -108,12 +104,14 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
self.summarize_load_times()
# TODO: fix hang in summarize_load_times
# self.summarize_load_times()
return sharded_state_dict
def summarize_load_times(self):
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
for key, times in sorted(timers.items()):
times_sum = sum(times)
max_times = torch.tensor([times_sum], device='cuda')
......@@ -176,7 +174,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed()
def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
""" Group tensors by chunk and then pick the tensor with the lowest rank.
"""Group tensors by chunk and then pick the tensor with the lowest rank.
NOTE: with proper loading overlap, loading from randomized ranks
(instead of the smallest one) could be beneficial here.
......@@ -247,3 +245,10 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
return sharded_tensor.data
dict_list_map_inplace(_fill_in_data, sharded_state_dict)
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import logging
import os
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import List
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import zarr
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
from .base import (
LoadShardedStrategy,
SaveShardedStrategy,
StrategyAction,
register_default_strategy,
)
logger = logging.getLogger(__name__)
numpy_to_torch_dtype_dict = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float64,
np.complex64: torch.complex64,
np.complex128: torch.complex128,
np.dtype('bool'): torch.bool,
np.dtype('uint8'): torch.uint8,
np.dtype('int8'): torch.int8,
np.dtype('int16'): torch.int16,
np.dtype('int32'): torch.int32,
np.dtype('int64'): torch.int64,
np.dtype('float16'): torch.float16,
np.dtype('float32'): torch.float32,
np.dtype('float64'): torch.float64,
np.dtype('complex64'): torch.complex64,
np.dtype('complex128'): torch.complex128,
}
torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
try:
import tensorstore
# Register a bfloat16 type with this import
import tensorstore # pylint: disable=unused-import
HAS_BFLOAT16 = True
numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
......@@ -41,11 +51,28 @@ try:
except ImportError:
HAS_BFLOAT16 = False
_import_trigger = None
logger = getLogger(__name__)
def register_default_zarr_strategies():
"""Register default strategies related to Zarr backend."""
register_default_strategy(
StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1)
)
class ZarrSaveShardedStrategy(SaveShardedStrategy):
def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
"""Save strategy for Zarr backend."""
def __init__(self, backend: str, version: int):
super().__init__(backend, version)
logger.warning(
f'`zarr` distributed checkpoint backend is deprecated.'
' Please switch to PyTorch Distributed format (`torch_dist`).'
)
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
sharded_tensors = list(nested_values(sharded_state_dict))
arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
for ten, arr in zip(sharded_tensors, arrays):
_save_to_existing_array(ten, arr)
......@@ -54,24 +81,41 @@ class ZarrSaveShardedStrategy(SaveShardedStrategy):
def _create_or_open_zarr_arrays(
sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[zarr.Array]:
) -> List[Optional[zarr.Array]]:
"""Returns list of zarr arrays corresponding to given tensors.
For a sharded tensors that:
a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
b) is main replica but not the first chunk,
opens the arrays created in (a) (possibly by other process)
c) otherwise, sets the corresponding array to None since it won't be used
Args:
sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank
that will be saved to checkpoint
checkpoint_dir (Path): checkpoint in which the arrays will be created
"""
arrays = []
for ten in sharded_tensors:
if _should_create_array(ten):
_create_zarr_array(ten, checkpoint_dir)
# TODO: maybe reuse the opened arrays
arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
arrays.append(arr)
torch.distributed.barrier()
for ten in sharded_tensors:
# if is_main_replica(ten.replica_id) and set(ten.global_offset) == {0}:
# continue
# Open arrays created above by other processes
for arr_idx, ten in enumerate(sharded_tensors):
if arrays[arr_idx] is not None:
# array created by this process
assert _should_create_array(ten), ten
continue
if not is_main_replica(ten.replica_id):
# this array won't be needed for saving and can stay None
continue
open_kwargs = {}
if ten.flattened_range is not None:
open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
str(checkpoint_dir / f'{ten.key}.sync')
)
arr = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs)
arrays.append(arr)
arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
return arrays
......@@ -83,9 +127,10 @@ def _should_create_array(ten: ShardedTensor):
)
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: zarr.Array):
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
if not is_main_replica(sharded_tensor.replica_id):
return
assert arr is not None
x = sharded_tensor.data
x = x.detach().cpu()
torch.cuda.synchronize()
......@@ -114,6 +159,7 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
fill_value=None,
write_empty_chunks=True,
)
logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
except zarr.errors.ContainsArrayError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} already exists'
......@@ -127,12 +173,21 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
class ZarrLoadShardedStrategy(LoadShardedStrategy):
"""Load strategy for the Zarr backend."""
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(
partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_zarr_shape_dtype(path):
arr = zarr.open(path, 'r')
return arr.shape, arr.dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
......@@ -142,12 +197,7 @@ class ZarrLoadShardedStrategy(LoadShardedStrategy):
def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
try:
arr = zarr.open(checkpoint_dir / sharded_tensor.key, 'r')
except zarr.errors.PathNotFoundError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} not found'
) from e
arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')
if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
_msg = (
......@@ -161,7 +211,22 @@ def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
return postprocess_numpy_array(x, sharded_tensor)
def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
try:
return zarr.open(str(path), mode, **open_kwargs)
except zarr.errors.PathNotFoundError as e:
ckpt_dir = path.parent
err_msg = f'Array {path} not found'
if ckpt_dir.exists():
ckpt_files = [f.name for f in ckpt_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
else:
err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
raise CheckpointingException(err_msg) from e
def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
"""Turn numpy array to torch tensor."""
x = loaded_array
if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
x = x.astype(np.dtype('float32'))
......@@ -189,10 +254,12 @@ def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=
def flatten_range(sharded_tensor, x):
"""Apply flattened range to a tensor."""
return x.flatten()[sharded_tensor.flattened_range]
def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
"""Pad tensor to the expected shape."""
pad_args = []
assert len(x.shape) == len(expected_sharded_ten.local_shape)
# Reversed iteration order because F.pad expects so
......@@ -204,9 +271,10 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
if x_sh == exp_sh:
pad_args.extend((0, 0))
elif x_sh > exp_sh:
assert (
False
), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
assert False, (
f'Expected shape ({exp_sh}) smaller than actual ({x_sh})'
f' for {repr(expected_sharded_ten)}'
)
else:
pad_args.extend((0, exp_sh - x_sh))
# TODO: behavior control with envvar is for testing purposes only, remove it
......@@ -224,7 +292,30 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
'zarr', 1
)
def load_zarr_based_sharded_metadata(
checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
"""Load metadata of Zarr arrays.
Args:
checkpoint_dir (str): checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
an array shape and dtype for a given Zarr array path
"""
sharded_state_dict = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir() or not (subdir / '.zarray').exists():
continue
key = subdir.name
arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
sharded_state_dict[key] = ShardedTensor(
key,
None,
numpy_to_torch_dtype_dict[arr_dtype],
arr_shape,
arr_shape,
tuple(0 for _ in arr_shape),
tuple(1 for _ in arr_shape),
)
return sharded_state_dict
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from typing import Dict, Optional, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
LocalNonpersitentObject,
LocalNonpersistentObject,
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
"""Unique id of the sharded tensor data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_tensor (ShardedTensor): sharded tensor representing the data shard
Returns (tuple): unique id of a data shard
"""
f_range = sharded_tensor.flattened_range
return (
sharded_tensor.key,
sharded_tensor.global_offset,
None if f_range is None else (f_range.start, f_range.stop),
)
def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId:
"""Unique id of the sharded object data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_object (ShardedObject): sharded object representing the data shard
Returns (tuple): unique id of a data shard
"""
return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape)
def extract_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor objects
from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor (keeping the original state dict structure)
- state dict with all objects other than ShardedTensor
(keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
def extract_sharded_tensors_and_factories(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects
from a given state dict with any objects.
Args:
sharded_state_dict:
state dict possibly containing ShardedTensor and ShardedTensorFactory objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor and ShardedTensorFactory
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
)
......@@ -29,16 +93,127 @@ def extract_sharded_tensors_and_factories(
def extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedTensor, ShardedTensorFactory
and LocalNonpersistentObject objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory
and LocalNonpersistentObject objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict,
lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)),
)
def extract_sharded_base(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only ShardedBase from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedBase objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedBase objects (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase))
def extract_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
"""Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.
Args:
sharded_state_dict: state dict possibly containing LocalNonpersistentObjects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all LocalNonpersistentObjects
(keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject)
)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
"""Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict
prefix (str): prefix to be prepended
Returns:
None: state dict is modified in-place
"""
def add_prefix(t):
if isinstance(t, ShardedTensor):
t.key = f'{prefix}.{t.key}'
if isinstance(t, ShardedBase):
t.key = f'{prefix}{t.key}'
return t
dict_list_map_inplace(add_prefix, sharded_state_dict)
def replace_prefix_for_sharding(
sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
"""Replaces the given prefix in *all* sharded keys in a given state dict.
Errors out if some key does not begin with a given prefix.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
old_prefix (str): prefix to be replaced in each key
new_prefix (str): new prefix
Returns:
None: state dict is modified in place
"""
def _replace_prefix(x):
if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
if not x.key.startswith(old_prefix):
raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
return x
dict_list_map_inplace(_replace_prefix, sharded_state_dict)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
"""Replaces prefixes *only in keys matching* with one of prefixes in the map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
prefix_map (Dict[str, str]):
map of old->new prefixes. The first matching prefix for each key is used
Returns:
None: state dict is modified in place
"""
def _replace_prefixes(x):
if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
return x
for old_prefix, new_prefix in prefix_map.items():
if x.key.startswith(old_prefix):
x.key = (
f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
)
break
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from collections import Counter, defaultdict
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
import numpy as np
import torch
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import (
diff,
extract_matching_values,
map_reduce,
nested_values,
)
from megatron.core.dist_checkpointing.mapping import (
CommonStateDict,
ShardedBase,
ShardedObject,
ShardedStateDict,
is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.base import (
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
if TYPE_CHECKING:
from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata
logger = logging.getLogger(__name__)
# pylint: disable=line-too-long
# list of local saved/loaded ShardedBase objects
_LocalMetadata = List[Union[ShardedTensor, ShardedObject]]
# list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank)
_GlobalMetadata = List[_LocalMetadata]
class StrictHandling(Enum):
"""Determines handling of load mismatch (non-empty "unexpected" or "missing" keys).
Different flags carry different implications on performance and behaviour and
are divided into two groups:
- *_UNEXPECTED
- *_ALL
The first group ignores missing keys (present in the checkpoint but missing
in the sharded state dict) which is created in order to avoid inter-rank
metadata exchange. Note that the metadata exchange will happen anyway
with `load(..., validate_access_integrity=True)` flag in which case using the
`*_ALL` option is recommended as it provides a more thorough check with no
performance penalty wrt. `*_UNEXPECTED` group.
All options except for the first one (`ASSUME_OK_UNEXPECTED`) require
extra disk access before the load in order to remove unexpected keys
from the sharded state dict requested to load.
"""
# Relies on the underlying strategy to raise error on unexpected keys
ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
# Logs (with WARNING level) "unexpected" keys. Missing keys are ignored.
# This is treated as a reasonable default for a "non-strict" load
LOG_UNEXPECTED = 'log_unexpected'
# Logs (with WARNING level) all mismatched keys.
LOG_ALL = 'log_all'
# Raise error on unexpected keys before load attempt.
# Gives cleaner error message than `ASSUME_OK_UNEXPECTED` but requires
# extra disk access.
RAISE_UNEXPECTED = 'raise_unexpected'
# Raise error on any mismatch. Similar to `RAISE_UNEXPECTED` but requires
# metadata exchange.
RAISE_ALL = 'raise_all'
# "Unexpected" mismatches are not reported, but returned by the `load`
# function along with the loaded state dict. Missing keys are ignored.
RETURN_UNEXPECTED = 'return_unexpected'
# All mismatches are returned along with the loaded state dict.
RETURN_ALL = 'return_all'
# Simply ignores mismatches (not recommended)
IGNORE_ALL = 'ignore_all'
@staticmethod
def requires_explicit_ckpt_mismatch_check(val: 'StrictHandling') -> bool:
"""Whether a given strict flag involves mismatch check against the checkpoint."""
return val != StrictHandling.ASSUME_OK_UNEXPECTED
@staticmethod
def requires_global_app_metadata(val: 'StrictHandling') -> bool:
"""Whether a given strict option requires global metadata for validation."""
return val in (
StrictHandling.IGNORE_ALL,
StrictHandling.RAISE_ALL,
StrictHandling.RETURN_ALL,
StrictHandling.LOG_ALL,
)
@staticmethod
def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool:
"""Whether a given strict option results in extra return value from the `load` function."""
return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)
def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
"""Parse user passed strict flag from a string to StrictHandling instance.
Args:
strict (str, StrictHandling): strict flag to parse. If already an instance
of StrictHandling, this function is a noop.
Returns:
StrictHandling: enum instance
"""
if isinstance(strict, StrictHandling):
return strict
try:
return StrictHandling(strict)
except (ValueError, TypeError) as e:
raise ValueError(f'Invalid strict flag: {e}') from e
def validate_integrity_and_strict_load(
sharded_state_dict: ShardedStateDict,
strict: StrictHandling,
validate_access_integrity: bool,
local_metadata: Optional[_LocalMetadata] = None,
global_metadata: Optional[_GlobalMetadata] = None,
ckpt_sharded_metadata: Optional['CkptShardedMetadata'] = None,
) -> Tuple[ShardedStateDict, Set[str], Set[str]]:
"""Validates sharding integrity and potential mismatches with the checkpoint.
`validate_access_integrity` controls sharding integrity check (orthogonal
to strictness checking) which verifies `sharded_state_dict` runtime completeness
(in isolation from the actual checkpoint).
`strict` flag controls handling of mismatches between the requested
sharded state dict to load and the actual checkpoint. See `StrictHandling`
docs for details regarding flag behavior and performance implications
(disk interactions or inter-rank communication).
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to verify.
strict (StrictHandling): flag determining how to handle sharded keys mismatch.
validate_access_integrity (bool): whether to perform sharding validation.
local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
Defaults to None, in which case it's determined based on `sharded_state_dict`.
global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
(exchanged between ranks). Defaults to None, in which case "missing"
keys are not determined.
ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
from the checkpoint. Defaults to None, which only makes sense
for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.
Returns:
Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict
without unexpected keys, missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. Additionally,
missing keys might be erroneously empty (depending on `strict` value).
"""
missing_keys, unexpected_keys = [], []
if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
if ckpt_sharded_metadata is None:
raise CheckpointingException(
'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.'
)
if local_metadata is None:
local_metadata = [
sh_base.without_data() for sh_base in nested_values(sharded_state_dict)
]
# We don't want to check for missing keys even if we could
_skip_missing_keys = strict in (
StrictHandling.ASSUME_OK_UNEXPECTED,
StrictHandling.LOG_UNEXPECTED,
StrictHandling.RAISE_UNEXPECTED,
StrictHandling.RETURN_UNEXPECTED,
)
missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys(
ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata
)
sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys)
if strict == StrictHandling.IGNORE_ALL:
missing_keys, unexpected_keys = [], []
elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL):
maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True)
elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL):
maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False)
if validate_access_integrity:
if global_metadata is None:
raise CheckpointingException(
'Cannot check sharding intergrity without global_metadata (None).'
)
validate_sharding_integrity(global_metadata)
return sharded_state_dict, missing_keys, unexpected_keys
def verify_checkpoint_and_load_strategy(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
"""Verifies if checkpoint metadata exists and matches given strategies.
If no strategies are passed, they are determined based on the checkpoint metadata.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified
if compatible with the checkpoint content. If None, the default sharded load strategy
for the checkpoint backend will be returned.
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
if compatible with the checkpoint content. If None, the default common load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
elif isinstance(sharded_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
if common_strategy is None:
common_strategy = get_default_strategy(
StrategyAction.LOAD_COMMON,
saved_config.common_backend,
saved_config.common_backend_version,
)
elif isinstance(common_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy)
sharded_strategy.check_backend_compatibility(saved_config.sharded_backend)
sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version)
common_strategy.check_backend_compatibility(saved_config.common_backend)
common_strategy.check_version_compatibility(saved_config.common_backend_version)
return sharded_strategy, common_strategy
def adjust_non_strict_load(
sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
"""Adjusts sharded state dict removing keys not existing in the checkpoint.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to modify
sharded_keys_to_remove (Set[str]): keys to remove from the state dict
Returns:
ShardedStateDict: state dict without ShardedBase objects with specified keys
"""
def is_unexpected_key(x: ShardedBase):
assert isinstance(x, ShardedBase), f'Unexpected type {type(x)}'
return x.key in sharded_keys_to_remove
_, sharded_state_dict = extract_matching_values(sharded_state_dict, is_unexpected_key)
return sharded_state_dict
def _determine_missing_and_unexpected_keys(
ckpt_sharded_metadata: 'CkptShardedMetadata',
local_metadata: _LocalMetadata,
global_metadata: Optional[_GlobalMetadata] = None,
) -> Tuple[Set[str], Set[str]]:
"""Determines load mismatches based on metadata.
There is an asymmetry between "unexpected" and "missing" keys.
Unexpected keys can be determined based only on local metadata.
Missing keys must be based on global metadata, since other ranks might access
different keys than the current rank.
In consequence, the return value of this function is different on each rank:
"missing_keys" are equal, but "unexpected_keys" might differ across ranks.
Args:
ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
constructed based on the checkpoint content
local_metadata (_LocalMetadata): list of local ShardedBase objects
requested to be loaded by this rank
global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
requested to be loaded by all ranks. Defaults to None, in which case
returned "missing" keys are empty.
Returns:
Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. If passed
`global_metadata` is empty, returned missing keys are empty as well.
"""
local_accessed_keys = set(sh_base.key for sh_base in local_metadata)
ckpt_keys = set(sh_base.key for sh_base in ckpt_sharded_metadata.values())
unexpected_keys = local_accessed_keys - ckpt_keys
if global_metadata is not None:
global_accessed_keys = set(
sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata
)
missing_keys = ckpt_keys - global_accessed_keys
else:
missing_keys = set()
if missing_keys:
logger.debug(f'Dist ckpt load missing keys: {missing_keys}')
if unexpected_keys:
logger.debug(f'Dist ckpt load unexpected keys: {unexpected_keys}')
return missing_keys, unexpected_keys
def maybe_report_missing_and_unexpected_keys(
missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
"""Raises or logs an error in case missing or unexpected keys are non-empty.
Args:
missing_keys (Set[str]): missing keys in the state dict
unexpected_keys (Set[str]): unexpected keys in the state dict
raise_error: If True, raises error on mismatch. Otherwise, logs mismatch
with WARNING level.
Returns:
None
Raises:
CheckpointingException: if `raise_error` is True and at least one of
`missing_keys` or `unexpected_keys` are non-empty.
"""
if not missing_keys and not unexpected_keys:
return
missing_title_msg = (
f'Some keys found in the checkpoint are missing in the provided sharded state dict. '
)
missing_body_msg = f'Missing keys (for all ranks): {missing_keys}. '
unexpected_title_msg = f'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. '
unexpected_body_msg = f'Unexpected keys (for this rank): {unexpected_keys}. '
error_msg = ''
if missing_keys:
error_msg += missing_title_msg
if unexpected_keys:
error_msg += unexpected_title_msg
error_msg += '\n'
if missing_keys:
error_msg += missing_body_msg
if unexpected_keys:
error_msg += unexpected_body_msg
if raise_error:
raise CheckpointingException(error_msg)
else:
logger.warning(error_msg)
def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
"""Validate consistancy across ranks for the common state dict
We save the common state dict only on rank 0. We validate to make sure that the common dict is consistant across ranks before saving.
Args:
common_state_dict: The common state dict present in all ransk
"""
# Gather the common state dict across ranks onto rank 0 for comparison
rank = torch.distributed.get_rank()
other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
common_state_dict_diff = {}
if rank == 0:
main_rank_state_dict = common_state_dict
for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
if only_left or only_right or mismatch:
common_state_dict_diff[rank] = (only_left, only_right, mismatch)
if len(common_state_dict_diff) != 0:
logger.warning(
f'There is difference in the common state dict in different ranks. The differences are {common_state_dict_diff}'
)
def validate_sharding_integrity(
global_metadata: _GlobalMetadata, common_state_dict: CommonStateDict = None
) -> None:
"""Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.
Local ShardedTensors and ShardedObject metadata is exchanged with `torch.distributed.all_gather_object`
and then process with global rank 0 checks if main replicas of the shards:
- cover the whole global tensors
- don't overlap
Args:
global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.
common_state_dict (CommonStateDict): The common state dict stored by rank 0
Returns:
None
Raises:
CheckpointingException for invalid access pattern
"""
if common_state_dict:
_validate_common_state_dict(common_state_dict)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(global_metadata):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
"""Ensure uniqueness of saved objects."""
unique_keys = [
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObject are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
def determine_global_metadata(
sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
"""Exchanges local metadata with `all_gather_object` to determine global metadata.
Args:
sharded_state_dict (ShardedStateDict): local sharded state dict
Returns:
Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects with stripped data
"""
local_metadata = [ten.without_data() for ten in nested_values(sharded_state_dict)]
global_metadata = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(global_metadata, local_metadata)
return local_metadata, global_metadata
def validate_sharded_objects_handling(
sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
"""Checks if either of the passed strategies can handle sharded objects.
Args:
sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading
common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading
Returns:
None
Raises:
CheckpointingException: if both strategies can't handle ShardedObjects
"""
if (
not sharded_strategy.can_handle_sharded_objects
and not common_strategy.can_handle_sharded_objects
):
raise CheckpointingException(
f'Either sharded strategy or common strategy must implement ShardedObjects handling.'
f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False'
)
## How to use pytorch FSDP2?
Add these flag to enable Torch FSDP2.
```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```
It is worth noting that CUDA_MAX_CONNECTIONS=1 should not be enabled to ensure that the communication of FSDP and the computation on the primary stream can be fully parallelized.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from packaging.version import Version
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager
import torch
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
class _BaseDataParallel(MegatronModule):
"""A template class for DistributedDataParallel implementations."""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super().__init__(config=config)
self.module = module
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
try:
yield
finally:
pass
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
pass
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
pass
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
pass
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
pass
def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
import torch
from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..transformer.transformer_config import TransformerConfig
from ..utils import is_float8tensor, log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
logger = logging.getLogger(__name__)
class DistributedDataParallel(_BaseDataParallel):
"""
DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
communication with backprop computation by breaking up full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * parallel_state.get_data_parallel_world_size()
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if parallel_state.get_pipeline_model_parallel_rank() > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.param_to_bucket_group = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
self.params_with_grad = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
# Track params with grad to enable direct setting
# of param.grad_added_to_main_grad
self.params_with_grad.append(param)
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def _allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor
):
param_and_grad_dtype_to_params = {}
param_and_grad_dtype_to_offsets = {}
param_and_grad_dtype_to_indices = {}
# Group parameters by their gradient type.
for param in input_params:
assert param.requires_grad
param_dtype = param.dtype
if is_float8tensor(param):
# Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
# dtype (usually a higher precision dtype such as bfloat16), but its actual
# data is stored in the form of a torch uint8 tensor within the Float8Tensor's
# ".data" attribute. Therefore, when creating the param buffer for fp8 params,
# it is necessary to use torch.uint8, not the "fake" dtype got from
# "param.dtype".
param_dtype = torch.uint8
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
# Get the index of each param among the params with same dtype, if a param is fp8,
# use its "fake" high precision dtype to find which params have same dtype with it.
# For example:
# Case 1:
# params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 1, 2, 3],
# }
# Case 2:
# params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
# param_and_grad_dtype_to_indices = {
# (torch.bfloat16, torch.float32): [0, 3],
# (torch.uint8, torch.float32): [1, 2],
# }
# We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
indices.append(offset)
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices
if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
if self.ddp_config.average_in_collective:
# Collective is averaging gradients in collective with data_parallel_group.
assert (
gradient_scaling_factor
/ parallel_state.get_data_parallel_world_size(with_context_parallel=True)
== target_gradient_scaling_factor
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
_ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
)
)
# In some scenarios, we want to put buckets from different buffers into a group so that
# their communication can be aggregated. For example, when there are both fp8 buffers
# and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
# bucket and a bf16 bucket, which doubles the number of communication kernels, and
# because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
# communications will prevent the overlap of the communication kernels with computation
# kernels.
# If bucketing is explicitly disabled, then put all buckets in a buffer into a single
# bucket group.
bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
if self.ddp_config.num_distributed_optimizer_instances > 1:
assert (
self.ddp_config.use_distributed_optimizer
), 'Partial DistOpt cannot be used without DistOpt'
communication_stream = torch.cuda.Stream(device=torch.cuda.current_device())
for bucket_group in bucket_groups:
bucket_group.inter_distributed_optimizer_instance_group = (
parallel_state.get_inter_partial_data_parallel_group()
)
bucket_group.communication_stream = communication_stream
# Set `next_param_gather_bucket_group` for different bucket groups by iterating through
# buckets in reverse order (since all-gathers happen in reverse order of buckets).
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
bucket_groups[num_bucket_groups - i - 1]
)
# Create map from param to bucket group, used in pre_hook.
for bucket_group in bucket_groups:
for bucket in bucket_group.buckets:
for param in bucket.params_list:
self.param_to_bucket_group[param] = bucket_group
return buffers, bucket_groups
if config.calculate_per_token_loss:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = (
1.0 / parallel_state.get_expert_model_parallel_world_size()
)
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Allocate the param+grad buffers for dense params' grads.
self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
dense_params,
parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
),
gradient_scaling_factor=gradient_scaling_factor,
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
_allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_expert_data_parallel_group(),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
# Accumulation function for the gradients need to be stored so they
# don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_backward_post_hook(param))
self.grad_accs.append(grad_acc)
self.use_forward_hook = (
self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
)
self.remove_forward_pre_hook_handles = {}
if self.use_forward_hook:
self.enable_forward_pre_hook()
self.overlap_param_gather_with_optimizer_step = False
def enable_forward_pre_hook(self):
"""
Enable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
assert len(self.remove_forward_pre_hook_handles) == 0
# Register forward pre-hook for all sub-modules.
for module in self.module.modules():
self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
self._make_forward_pre_hook()
)
def disable_forward_pre_hook(self):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
for module in self.module.modules():
assert self.remove_forward_pre_hook_handles[module] is not None
self.remove_forward_pre_hook_handles[module].remove()
del self.remove_forward_pre_hook_handles[module]
assert len(self.remove_forward_pre_hook_handles) == 0
# Force synchronize parameters.
self.start_param_sync(force_sync=True)
def _make_forward_pre_hook(self):
"""
Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
when a module uses a parameter in a bucket with a still incomplete all-gather).
"""
def hook(module, *unused):
assert (
self.use_forward_hook
), "Should use pre-hook only when overlap_param_gather is True"
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters without an associated buffer (such parameters have a
# .requires_grad field equal to False).
if param not in self.param_to_bucket_group:
continue
assert param.requires_grad
# If aligning param all-gather across pipeline stages, all-gather is dispatched
# by start_param_sync calls in core/pipeline_parallelism/schedules.py.
# If overlapping param all-gather with optimizer step, then all-gather has
# already been dispatched in optimizer step.
skip_next_bucket_dispatch = (
self.ddp_config.align_param_gather
or self.overlap_param_gather_with_optimizer_step
)
self.param_to_bucket_group[param].finish_param_sync(
skip_next_bucket_dispatch=skip_next_bucket_dispatch
)
return hook
def _make_backward_post_hook(self, param: torch.nn.Parameter):
"""
Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
ready (i.e., when all grads in a bucket have been computed in all microbatches
in a batch).
"""
def hook(*unused):
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)
return hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = False
try:
yield
finally:
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.is_last_microbatch = True
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync:
# If overlapping param AG with optimizer step, AG should not be dispatched again
# in forward_backward_step.
if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
return
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_param_sync(force_sync=force_sync)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.finish_grad_sync()
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.params_with_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
bucket_group.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_expert_data_parallel_group()
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True, partial_data_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
overlap_param_gather: bool = False
"""If true, overlap param all-gather with forward compute."""
align_param_gather: bool = False
"""If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
PP stage will independently launch as needed.
"""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
originally allocated model parameters, otherwise issue all-reduce collectives.
"""
num_distributed_optimizer_instances: int = 1
"""Sets the factor by which the DP domain is sharded to have the partial DistOpt
enabled. Defaults to 1, which means DistOpt is across entire DP domain.
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment