Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
......@@ -28,9 +28,10 @@ async_calls = AsyncCallsQueue()
def get_default_strategy(action: StrategyAction, backend: str, version: int):
"""Retrieves a default strategy for a given action, backend and version."""
error_hint: str = None
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages'
from .tensorstore import register_default_tensorstore_strategies
register_default_tensorstore_strategies()
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" FS Reader with metadata cached support. """
import os
from typing import Union
from torch.distributed.checkpoint import FileSystemReader, Metadata
class CachedMetadataFileSystemReader(FileSystemReader):
"""
Extends FileSystemReader to cache metadata for improved performance.
Attributes:
_cached_metadata (Metadata or None): Cached metadata from the file system.
"""
def __init__(self, path: Union[str, os.PathLike]) -> None:
"""
Initialize with file system path.
Args:
path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
"""
super().__init__(path=path)
self._cached_metadata = None
def read_metadata(self) -> Metadata:
"""
Read metadata from file system, caching for subsequent calls.
Returns:
Metadata: Checkpoint metadata.
"""
if self._cached_metadata is None:
self._cached_metadata = super().read_metadata()
return self._cached_metadata
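A minimal usage sketch of the reader above (the checkpoint path is hypothetical; the second call is answered from the in-memory cache instead of storage):

reader = CachedMetadataFileSystemReader('/path/to/checkpoint')  # hypothetical path
md_first = reader.read_metadata()   # reads `.metadata` from storage
md_again = reader.read_metadata()   # served from the cache, no storage access
assert md_first is md_again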
......@@ -69,7 +69,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
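As background for the `weights_only=False` arguments added above, a hedged sketch of why the flag is needed (file path hypothetical): the common state dict may contain arbitrary pickled Python objects, which recent PyTorch versions reject under the `weights_only=True` default.

import argparse
import torch

torch.save({'args': argparse.Namespace(lr=1e-4)}, '/tmp/common.pt')  # hypothetical file
# Under the weights_only=True default of newer PyTorch, loading non-allowlisted
# classes such as argparse.Namespace fails; passing weights_only=False restores
# the old unpickling behavior.
obj = torch.load('/tmp/common.pt', map_location='cpu', weights_only=False)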
......@@ -95,12 +95,12 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path)
loaded_obj = torch.load(load_path, weights_only=False)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path)
loaded_obj = torch.load(old_load_path, weights_only=False)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import gc
import dataclasses
import logging
import os
import queue
from contextlib import contextmanager
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
......@@ -20,6 +22,8 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from .async_utils import _disable_gc
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
......@@ -35,19 +39,6 @@ def _get_write_results_queue():
return _results_queue
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
......@@ -76,6 +67,8 @@ class FileSystemWriterAsync(FileSystemWriter):
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
self.can_run_decentralized_global_plan: bool = True
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
......@@ -99,7 +92,7 @@ class FileSystemWriterAsync(FileSystemWriter):
self.thread_count > 1
), "thread_count must be at least 2 if separation_hint is provided"
bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
item_buckets = _split_by_size_and_type(bins, plan.items, self.separation_hint)
item_buckets = _split_by_size_and_type(bins, plan.items)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
......@@ -113,6 +106,23 @@ class FileSystemWriterAsync(FileSystemWriter):
file_count += 1
return file_name
def _clone_if_needed(ten: torch.Tensor):
"""Clone if we detect incontiguous storage for CPU tensors
Makes sure we perform a `clone` only if we detect incontiguous storage,
so that we don't blow up host memory unnecessarily.
TODO: For persistent worker, this work should be changed to move the cpu tensor
to shared_memory.
"""
ten = ten.detach()
if ten.device.type != "cpu":
# We do D2H later when the async_request is scheduled for both sync / async
# checkpointing
return ten
is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
return ten.clone() if is_view else ten
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for group_name, group_buckets in _split_by_separation_hint(
......@@ -125,7 +135,7 @@ class FileSystemWriterAsync(FileSystemWriter):
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
(item, _clone_if_needed(planner.resolve_data(item)))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
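For reference, a standalone illustration of the `_clone_if_needed` view check above (plain PyTorch, not part of this module):

import torch

base = torch.zeros(1024)
view = base[:10]  # shares the 1024-element storage of `base`
# Same check as in _clone_if_needed: the backing storage is larger than the tensor
assert view.untyped_storage().size() != view.numel() * view.itemsize
owned = view.clone()  # the clone drops the reference to the large storage
assert owned.untyped_storage().size() == owned.numel() * owned.itemsize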
......@@ -147,23 +157,49 @@ class FileSystemWriterAsync(FileSystemWriter):
end = time()
logger.debug(f"D2H and push, time: {end - start}")
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: None (if there is nothing to write on this rank) or a tuple of:
- the function that saves the data
- arguments to that function
1) the function that saves the data.
2) the function that stages the GPU tensors to a destination for async checkpointing.
This function should be self-contained.
3) arguments to that function in 1).
"""
if not self.write_buckets:
return None, ()
return (self.write_preloaded_data_multiproc, (self.write_buckets, self.results_queue))
return None, None, ()
return (
self.write_preloaded_data_multiproc,
partial(self.preload_tensors, self.write_buckets, True),
[torch.distributed.get_rank(), self.write_buckets, self.results_queue],
)
@staticmethod
def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
"""Preload tensors in state_dict to host memory through CPU memory
Args:
write_buckets(List): List of `WriteBucket`,
which includes what to be saved in a checkpoint
non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
"""
result = []
for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
if non_blocking:
torch.cuda.synchronize()
return result
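# Standalone illustration (hedged sketch, assuming a CUDA device is available) of
# the non_blocking D2H pattern used by preload_tensors above:
#   gpu_tensor = torch.randn(1024, device='cuda')
#   host_copy = gpu_tensor.to('cpu', non_blocking=True)  # copy may still be in flight
#   torch.cuda.synchronize()  # after this, host_copy is safe to read or write to disk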
@staticmethod
@_disable_gc()
def write_preloaded_data_multiproc(
write_buckets: List[WriteBucket], global_results_queue: mp.Queue
rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue
) -> None:
"""
Performs saving data to storage with multiple processes.
......@@ -186,6 +222,7 @@ class FileSystemWriterAsync(FileSystemWriter):
(or an Exception) from parallel write processes to the main training process
Returns: None
"""
logger = logging.getLogger(__name__)
w_start = time()
write_results_or_exc: Union[dict, Exception] = dict()
ctx = mp.get_context('fork')
......@@ -234,20 +271,16 @@ class FileSystemWriterAsync(FileSystemWriter):
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
else:
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
logger.debug('FileSystemWriterAsync: collected worker results successfully')
global_results_queue.put(write_results_or_exc)
w_end = time()
logger.debug(
f"{w_end}, rank: {torch.distributed.get_rank()},"
f" write(sync,parallel): {w_end - w_start}"
)
logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}")
@staticmethod
@_disable_gc()
......@@ -271,6 +304,8 @@ class FileSystemWriterAsync(FileSystemWriter):
Returns: None, the write result are put into the `queue`
"""
logger = logging.getLogger(__name__)
logger.debug(f'{local_proc_idx} started')
mem_before = _process_memory()
local_results = []
......@@ -288,6 +323,7 @@ class FileSystemWriterAsync(FileSystemWriter):
os.fsync(stream.fileno())
local_output = (local_proc_idx, local_results)
except Exception as e:
logger.debug(f'{local_proc_idx} failed')
local_output = (local_proc_idx, e)
results_queue.put(local_output)
......@@ -334,10 +370,23 @@ class FileSystemWriterAsync(FileSystemWriter):
)
return list(chain.from_iterable(write_results.values()))
def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Instead of assigning indices by plan order, uses PyT rank (same outcome).
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
return dataclasses.replace(
local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
)
def _split_by_size_and_type(
bins: int, items: List[WriteItem], separation_hint: Optional[str] = None
) -> List[List[WriteItem]]:
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Splits write items according to item size into close to uniform bins.
......@@ -353,24 +402,32 @@ def _split_by_size_and_type(
if bins == 1:
return [items]
bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
bytes_items: List[WriteItem] = []
tensor_items: List[WriteItem] = []
for wi in items:
container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
container.append(wi)
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_items.sort(key=_item_size, reverse=True)
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Then, assign tensors according to their sizes
for item in tensor_items:
# TODO replace with heapq
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
# Sort tensor items by size in decreasing order once and store the size with item
sized_tensors = [(item, _item_size(item)) for item in tensor_items]
sized_tensors.sort(key=itemgetter(1), reverse=True)
# Use a min heap for bin assignment
# Store (total_size_of_bin, bin_index) tuples
heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]
# Assign tensors using heap
for item, size in sized_tensors:
total_bin_size, bin_idx = heappop(heap)
buckets[bin_idx].append(item)
heappush(heap, (total_bin_size + size, bin_idx))
return buckets
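To make the new greedy assignment concrete, a small self-contained sketch of the same min-heap bin packing with made-up item sizes (not part of the module):

from heapq import heappop, heappush

def pack_sizes(sizes, bins):
    buckets = [[] for _ in range(bins)]
    heap = [(0, i) for i in range(bins)]      # (current bin size, bin index)
    for size in sorted(sizes, reverse=True):  # largest items first
        total, idx = heappop(heap)            # always fill the emptiest bin
        buckets[idx].append(size)
        heappush(heap, (total + size, idx))
    return buckets

print(pack_sizes([7, 5, 4, 3, 2, 1], bins=2))  # [[7, 3, 1], [5, 4, 2]]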
......
......@@ -2,12 +2,13 @@
import logging
from pathlib import Path
from time import time
from typing import Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import Metadata
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
......@@ -19,6 +20,7 @@ from megatron.core.dist_checkpointing.exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
exchange_loaded_objects_gather_object,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.strategies.base import (
......@@ -26,7 +28,12 @@ from megatron.core.dist_checkpointing.strategies.base import (
LoadShardedStrategy,
SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
from megatron.core.dist_checkpointing.utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
_ShardId,
debug_time,
)
from megatron.core.dist_checkpointing.validation import (
determine_global_metadata,
validate_sharding_integrity,
......@@ -34,6 +41,8 @@ from megatron.core.dist_checkpointing.validation import (
logger = logging.getLogger(__name__)
T = TypeVar('T', ShardedObject, ShardedTensor)
class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
"""Wraps arbitrary strategy and distributes the save during `save`.
......@@ -170,7 +179,9 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
self.exchange_algo = exchange_algo
self.cached_distribution: Optional[ShardDistribution] = None
self.cached_global_metadata: Optional[Metadata] = None
@debug_time("FullyParallelLoadStrategyWrapper.load", logger)
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Distributes the load and calls underlying strategy only for parts of the state dict.
......@@ -200,18 +211,20 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
a state dict that would be loaded with the underlying strategy
without this wrapper.
"""
loaded_state_dict = {}
if torch.distributed.get_world_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
# Step 1 and 2: exchange load metadata and distribute the load
start = time()
precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
end = time()
logger.debug(f'self.apply_loading_parallelization took {end - start}s')
start = end
with debug_time("self.apply_loading_parallelization", logger):
precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
sharded_state_dict
)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
# Step 3: load part of the checkpoint.
# Load only sharded objects first. ShardedTensors will be loaded separately
......@@ -219,88 +232,121 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
(sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
self._defer_loading_sharded_tensors(sharded_state_dict)
)
loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedObjects took {end - start}s')
start = end
(sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = (
self._defer_loading_sharded_objects(sharded_state_dict)
)
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
assert (
len(sharded_state_dict) == 0
), "sharded_state_dict is not empty after deferring tensors and objects"
with debug_time("base_load_ShardedObjects", logger):
# Load sharded objects first
loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)
with debug_time("base_load_ShardedTensors", logger):
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
with debug_time("self.exchange_loaded_tensors", logger):
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
)
end = time()
logger.debug(f'Base load of ShardedTensors took {end - start}s')
start = end
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
with debug_time("torch.cuda.synchronize", logger):
torch.cuda.synchronize()
all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects)
if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()):
missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
f'Missing object shards after fully parallel loading: {missing_object_shards}'
)
sync_start = time()
torch.cuda.synchronize()
end = time()
logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
logger.debug(f'self.exchange_loaded_tensors took {end - start}s')
self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects)
merge(loaded_state_dict, sharded_objects)
merge(loaded_state_dict, sharded_tensors)
if hasattr(self.base_strategy, "cached_global_metadata"):
self.cached_global_metadata = self.base_strategy.cached_global_metadata
return loaded_state_dict
@staticmethod
def _defer_loading_sharded_objects(
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedObject],
Dict[_ShardId, ShardedObject],
]:
return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id)
@staticmethod
def _defer_loading_sharded_tensors(
self, sharded_state_dict: ShardedStateDict
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedTensor],
Dict[_ShardId, ShardedTensor],
]:
"""Divides state dict into parts loaded by this vs other ranks.
return _defer_loading_sharded_items(
sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id
)
ShardedTensors with main replica_id will be loaded by this rank,
others will be received by other ranks (after loading from storage).
@staticmethod
def fill_in_deferred_sharded_objects(
sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any]
) -> None:
"""Fill in objects not loaded by current rank with objects from `loaded_objects` map.
Args:
sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
that will be divided.
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with ShardedTensors
- ShardedStateDict: sub-state dict with non-ShardedTensors
- Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified
by shard ids. This is a mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *this* rank
- Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedObjects are completely replaced with corresponding objects.
loaded_objects (Dict[_ShardId, Any]): dict allowing to map
ShardedObject from the sharded_state_dict to loaded objects.
sharded_tensors, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id
)
def wrap_non_main_replicas(x):
if isinstance(x, ShardedTensor):
# Assign shard to be loaded or not
if is_main_replica(x.replica_id):
to_load_shards[_sharded_tensor_shard_id(x)] = x
else:
unloaded_shards[_sharded_tensor_shard_id(x)] = x
return x
@staticmethod
def fill_in_deferred_sharded_tensors(
sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id
)
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
......@@ -339,34 +385,6 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
return precomputed_distribution
def fill_in_deferred_sharded_tensors(
self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
Returns:
"""
def fill_in_sharded_tensor(x):
if isinstance(x, ShardedTensor):
try:
x = loaded_tensors[_sharded_tensor_shard_id(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
......@@ -437,3 +455,61 @@ def distribute_main_replicas_with_precomputed_distribution(
sh_ten.replica_id = 0
else:
sh_ten.replica_id = 1
def _defer_loading_sharded_items(
sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId]
) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]:
"""Divides state dict into parts loaded by this vs other ranks.
Args:
sharded_state_dict (ShardedStateDict): state dict with sharded items
that will be divided.
item_type: The type of sharded item (ShardedObject or ShardedTensor)
shard_id_func: Function to get the shard ID for the item type
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with sharded items
- ShardedStateDict: sub-state dict with non-sharded items
- Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank
- Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_items, remaining_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, item_type)
)
def wrap_non_main_replicas(x: Any) -> Any:
if isinstance(x, item_type):
shard_id = shard_id_func(x)
if is_main_replica(x.replica_id):
to_load_shards[shard_id] = x
else:
unloaded_shards[shard_id] = x
return x
dict_list_map_inplace(wrap_non_main_replicas, sharded_items)
return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards
def _fill_in_deferred_sharded_items(
sharded_state_dict: ShardedStateDict,
loaded_items: Dict[_ShardId, Any],
item_type: type,
shard_id_func: Callable[[T], _ShardId],
) -> None:
"""Helper function to fill in items not loaded by current rank."""
def fill_in_sharded_item(x: Any) -> Any:
if isinstance(x, item_type):
try:
x = loaded_items[shard_id_func(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded item shard: {shard_id_func(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict)
......@@ -13,7 +13,7 @@ import logging
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
......@@ -27,7 +27,6 @@ from megatron.core.dist_checkpointing.dict_utils import (
extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
StateDict,
......@@ -84,11 +83,7 @@ def is_nd_flattened_tensor(sh_ten: Any) -> bool:
Returns:
bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
"""
return (
isinstance(sh_ten, ShardedTensor)
and sh_ten.flattened_range is not None
and len(sh_ten.global_shape) > 1
)
return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None
# information needed to restore. With current implementation, this is a nested state dict
......@@ -132,8 +127,12 @@ def apply_nd_flattened_tensors_reformulation(
try:
sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
except KeyError as e:
# Handle legacy checkpointing where 1-D flatten tensor metadata was not saved
if len(sh_ten.global_shape) == 1:
return sh_ten
raise CheckpointingException(
f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}'
f'Missing reformulation metadata for tensor {sh_ten}. '
f'Existing keys: {reformulation_metadata.keys()}'
) from e
ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
......@@ -235,13 +234,16 @@ def reformulate_single_nd_flattened_tensor(
):
# without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
# `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units
# `math.ceil` argument is an exact offset of the app next shard expressed
# in ckpt_local_shape units
next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))
logger.debug(
f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}'
f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}'
f'Generated the following number of overlap shards for each dimension: '
f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt '
f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} '
f'and chunk offset {sh_ten.local_chunk_offset_in_global()}'
)
reformulated_sh_tens = {}
for chunk_offset in product(*overlap_dim_offsets):
......@@ -286,7 +288,8 @@ def reformulate_single_nd_flattened_tensor(
# For each ckpt shard, we fill the appropriate application shard part
dest_ten = app_non_flat_ten
src_ten = ckpt_ten.view(ckpt_local_shape)
# We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice
# We don't need narrowing over `prepend_axis_num` axes so we take
# the [sh_ten.prepend_axis_num:] offsets slice
for (
dim,
offset_for_saved_tensor,
......
......@@ -4,7 +4,7 @@
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -16,19 +16,37 @@ from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
from .torch import MCoreSavePlanner
logger = getLogger(__name__)
from dataclasses import fields
def _compare_dataclasses(obj1, obj2):
if type(obj1) != type(obj2):
return f"Objects are of different types: {type(obj1)} and {type(obj2)}"
differences = []
for field in fields(obj1):
value1 = getattr(obj1, field.name)
value2 = getattr(obj2, field.name)
if value1 != value2:
differences.append(f"{field.name}: {value1} != {value2}")
return differences if differences else "All fields are equal"
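# Standalone illustration (hypothetical dataclass, not part of the module) of how
# _compare_dataclasses reports differing fields; it is used below to log why
# global-metadata reuse was rejected:
#   @dataclass
#   class P:
#       a: int
#       b: str
#   _compare_dataclasses(P(1, 'x'), P(1, 'y'))  # -> ["b: x != y"]
#   _compare_dataclasses(P(1, 'x'), P(1, 'x'))  # -> "All fields are equal"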
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]:
loaded_all_plans: Optional[List[SavePlan]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]:
"""
First stage of saving a state dict to storage.
......@@ -62,7 +80,7 @@ def save_state_dict_async_plan(
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- metadata from planning (or None if we reuse cached global metadata)
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning.
......@@ -80,6 +98,7 @@ def save_state_dict_async_plan(
global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan
global_md_verify_reuse = False
def local_step():
nonlocal local_plan
......@@ -101,11 +120,34 @@ def save_state_dict_async_plan(
return all_local_plans
# Execute local and global planning
# Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
# allow it (`can_run_decentralized_global_plan`) we gather the plans to create
# the metadata but prepare the plans independently on each rank.
# In the worst case we have to reduce_scatter all the plans.
start_plan = time()
if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
storage_writer, 'can_run_decentralized_global_plan', False
):
local_plan = local_step()
global_md_verify_reuse = verify_global_md_reuse(
loaded_all_plans, local_plan, rank, dist_wrapper
)
if not loaded_all_plans or not global_md_verify_reuse:
all_local_plans = dist_wrapper.gather_object(local_plan)
if dist_wrapper.is_coordinator:
_, global_metadata = planner.create_global_plan(all_local_plans)
global_metadata.all_local_plans = all_local_plans
else:
logger.debug(f"rank: {rank}, Passed cached global metadata")
global_metadata = None
local_plan = planner.create_decentralized_global_plan(local_plan)
local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
central_plan = local_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
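In short, planning now has three tiers. A simplified, hedged restatement of the control flow above (illustrative only; the real code also gathers plans to rebuild global metadata when reuse cannot be verified):

def choose_planning_path(cache_valid, planner, storage_writer):
    if cache_valid:
        return 'reuse cached central plan'
    if getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
        storage_writer, 'can_run_decentralized_global_plan', False
    ):
        return 'decentralized plan, no reduce_scatter'
    return 'reduce_scatter of all local plans'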
......@@ -118,13 +160,56 @@ def save_state_dict_async_plan(
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return (
(storage_writer, cast(Metadata, global_metadata), dist_wrapper),
(storage_writer, global_metadata, dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
global_md_verify_reuse,
)
def verify_global_md_reuse(
loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
) -> bool:
"""
Verifies that global metadata reuse is possible by checking the loaded plans from the
checkpoint are consistent, which means we have the same settings when resuming training.
Args:
loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
(stored in checkpoint metadata).
local_plan: SavePlan, The local save plan.
rank: Current process rank.
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: True iff the global metadata reuse is possible.
"""
logger.debug(f"verifying reuse of global metadata")
if not loaded_all_plans:
global_md_verify_reuse = False
logger.debug("loaded global metadata reuse verification: no loaded plans passed")
elif len(loaded_all_plans) == dist_wrapper.get_world_size():
local_verify_reuse = all(
getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
for f in fields(local_plan)
if f.name != 'storage_data'
)
if not local_verify_reuse:
logger.debug(
f"local_verify_reuse is False: diffs -"
f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
)
all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
# Check if all reduced results are True
global_md_verify_reuse = all_results.item() == 1
else:
global_md_verify_reuse = False
return global_md_verify_reuse
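# The MIN all-reduce over 0/1 flags above is the standard "all ranks must agree"
# check: the result is 1 only if every rank's local comparison passed. Standalone
# sketch (assumes an initialized process group and a CUDA device):
#   flag = torch.tensor([int(local_ok)], dtype=torch.int, device='cuda')
#   torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.MIN)
#   everyone_ok = flag.item() == 1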
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None:
......
File mode changed from 100755 to 100644
......@@ -55,6 +55,7 @@ from .base import (
StrategyAction,
register_default_strategy,
)
from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
TensorReformulationMetadata,
......@@ -126,7 +127,9 @@ def flatten_state_dict(
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
sh_tens: List[ShardedTensor],
rank: Optional[int] = None,
load_legacy_1d_flatten_tensors: bool = False,
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
......@@ -138,13 +141,12 @@ def sharded_tensor_to_torch_sharded_tensor(
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
This function handles 3 different type of ShardedTensors:
This function handles 2 different types of ShardedTensors:
1. Non-flat regular ShardedTensors (`not has_flattened_range`)
2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
3. N-D flattened ShardedTensors (`has_flattened_range`)
2. N-D flattened ShardedTensors (`has_flattened_range`)
(1) and (2) type are saved according to their original shape.
Type (3) however requires global shape adjustment for efficiency:
Type (1) is saved according to its original shape.
Type (2) however requires global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
......@@ -154,6 +156,8 @@ def sharded_tensor_to_torch_sharded_tensor(
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
should be loaded in a legacy way. Defaults to False.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
......@@ -163,41 +167,21 @@ def sharded_tensor_to_torch_sharded_tensor(
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1
for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()
if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
# Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
has_flattened_range = False
local_global_offsets = {}
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if is_flattened_range_1d:
# Type (2) case: 1D flattened ShardedTensors
for sh_ten in sh_tens:
assert len(sh_ten.global_offset) == 1, sh_ten
assert sh_ten.prepend_axis_num == 0, sh_ten
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
global_shape = some_sh_ten.global_shape
offsets_shape = (
some_sh_ten.local_shape
) # local shape is not flattened, we need it for chunk offsets
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
[
sh_ten.global_offset[0] + sh_ten.flattened_range.start
], # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
elif has_flattened_range:
if has_flattened_range:
# Type (3) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
......@@ -250,10 +234,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if is_flattened_range_1d:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
elif has_flattened_range:
if has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
......@@ -270,7 +251,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
# The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
placement = f"rank:{(rank + 1) % world_size}/cuda"
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
......@@ -296,7 +277,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# This won't be stored in the checkpoint, only for runtime purposes
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten
......@@ -305,6 +286,7 @@ def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
load_legacy_1d_flatten_tensors: bool = False,
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Convert state dict with ShardedTensors and ShardedObjects
to state dict compatible with PyT Dist format.
......@@ -348,7 +330,9 @@ def mcore_to_pyt_state_dict(
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
sh_tens, rank, load_legacy_1d_flatten_tensors
)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
......@@ -460,6 +444,7 @@ class MCoreSavePlanner(DefaultSavePlanner):
*args,
dedup_replicated_tensors: Optional[bool] = None,
nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
can_run_decentralized_global_plan: bool = True,
**kwargs,
) -> None:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
......@@ -468,6 +453,14 @@ class MCoreSavePlanner(DefaultSavePlanner):
kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
self.can_run_decentralized_global_plan = can_run_decentralized_global_plan
if can_run_decentralized_global_plan:
assert (
not dedup_replicated_tensors
), 'Cannot run decentralized plan with dedup_replicated_tensors=True'
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
def create_local_plan(self) -> SavePlan:
"""Adds IOBytes write request on non-coordinator ranks."""
......@@ -503,6 +496,23 @@ class MCoreSavePlanner(DefaultSavePlanner):
metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
return global_plan, metadata
def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Nothing to do, just some checks.
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan'
return local_plan
def transform_object(self, write_item: WriteItem, object: Any):
"""Make no transformations - bytes objects are already serialized."""
return object
......@@ -535,6 +545,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
else:
expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if loaded_shape != expected_shape:
if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape = sh_ten.global_shape
if loaded_shape == expected_shape:
continue
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
......@@ -634,6 +650,8 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.separation_hint = separation_hint
self.validated_loaded_metadata_reuse = False
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
......@@ -663,7 +681,14 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure:
loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
if loaded_all_plans is None:
logger.debug(
"no all_local_plans in metadata - can't verify global metadata reuse..."
)
args_cached_plans = (
self.cached_central_plan,
self.cached_local_plan,
......@@ -675,24 +700,44 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
self.validated_loaded_metadata_reuse,
) = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
coordinator,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
planner=MCoreSavePlanner(
dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False
),
cached_ckpt_structure=args_cached_plans,
loaded_all_plans=loaded_all_plans,
)
rank = torch.distributed.get_rank()
if self.use_cached_ckpt_structure:
if self.validated_cache_reuse:
if (
loaded_all_plans
and self.cached_global_metadata
and self.validated_loaded_metadata_reuse
):
if coordinator == rank:
logger.debug(
f"rank: {rank}, reuse global metadata from loaded"
f" .metadata, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
elif self.validated_cache_reuse:
logger.debug(f"rank: {rank}, cache validated")
if save_state_dict_ret[1]: # when global_metadata is not cached
self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata
# Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata)
elif coordinator == rank:
logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}")
logger.debug(
f"rank: {rank}, reuse global metadata cached from previous"
f" save iteration, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
......@@ -700,13 +745,13 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args()
save_fn, save_args = save_fn_args
save_fn, preload_fn, save_args = save_fn_args
def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn])
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
def can_handle_sharded_objects(self):
return True
......@@ -736,6 +781,12 @@ def get_reformulation_metadata(
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skip metadata reformulation.'
)
continue
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
......@@ -750,6 +801,10 @@ def get_reformulation_metadata(
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format."""
def __init__(self):
self.cached_global_metadata: Optional[Metadata] = None
super().__init__()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
......@@ -761,10 +816,18 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, reformulation_metadata
)
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
......@@ -776,15 +839,23 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
fsr = CachedMetadataFileSystemReader(checkpoint_dir)
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
fsr,
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
self.cached_global_metadata = (
fsr.read_metadata()
) # no storage interaction thanks to caching
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from logging import getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .tensorstore import _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None
......@@ -26,9 +25,16 @@ _import_trigger = None
timers = defaultdict(list)
logger = getLogger(__name__)
logger.warning(
'megatron.core.dist_checkpointing.two_stage module is deprecated'
' and will be removed in Megatron-Core v0.12. Please use'
' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
)
def timed(verbose=True):
"""Timing decorator."""
def timed_dec(fn):
name = fn.__name__
......@@ -59,6 +65,7 @@ class _ShardedTensorMetadata:
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
"""Id of a sharded tensor."""
return (sharded_tensor.key, sharded_tensor.global_offset)
......@@ -101,6 +108,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Main load method."""
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
......@@ -109,6 +117,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
return sharded_state_dict
def summarize_load_times(self):
"""Summarize load times."""
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
......@@ -124,6 +133,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
"""Load tensor from storage."""
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array(
ten_meta.sharded_tensor_no_data,
......@@ -136,12 +146,15 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed()
def maybe_init_gloo_group(self):
"""Create Gloo groups."""
if not self.cpu_transfer:
return
all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups):
# "two_stage" module will be deprecated, so not replace new_group()
# with ...parallel_state.create_group() func setting group_desc here.
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg
......@@ -211,7 +224,8 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
)
logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}\
({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
)
torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank
......
File mode changed from 100755 to 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values
from .exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
)
from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges
from .state_dict_utils import load_preprocess, save_preprocess
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
debug_time,
extract_sharded_base,
zip_strict,
)
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
@dataclass
class MCoreTensorAwareStateDict(TensorAwareStateDict):
"""
MCore-specific class defining the interface between the MCore state dict and checkpoint manager.
This class distinguishes between raw objects, the common state dict, and sharded state dicts
(tensor parts). It also handles optional metadata needed for fully parallel save/load.
"""
common: StateDict
sharded_state_dict: ShardedStateDict
_is_hollow: bool = False
@staticmethod
def _validate_params(algo):
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
@staticmethod
def _get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_distribution=None
):
if fully_parallel:
if cached_distribution is None:
distribution = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
logger.debug(f'MCore_TASD._get_distribution calculated distribution')
else:
distribution = cached_distribution
logger.debug(f'MCore_TASD._get_distribution used cache')
else:
distribution = (None, None, None, None)
logger.debug(f'MCore_TASD._get_distribution returned empty distribution')
return distribution
@staticmethod
def _remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
):
if fully_parallel:
for sh_base in nested_values(sharded_part):
# TODO remove redundant objects as well
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_to_saving_rank[shard_id] != torch.distributed.get_rank(
group=parallelization_group
):
sh_base.data = None
@classmethod
@debug_time("from_state_dict", logger)
def from_state_dict(
cls,
sharded_state_dict: ShardedStateDict,
algo: str = 'fully_parallel',
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
cached_metadata: ShardDistribution = None,
) -> Tuple[TensorAwareStateDict, ShardDistribution]:
"""
Constructs a TensorAwareStateDict from a sharded state dictionary.
This method preprocesses the input `sharded_state_dict`, validates parameters,
and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`.
Args:
sharded_state_dict: The input sharded state dictionary to be converted.
algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'.
- 'fully_parallel' enables fully parallel initialization.
parallelization_group (Optional): A distributed process group for parallelization.
cached_metadata (Optional): Precomputed metadata from previous saves.
- Reuses data that doesn't need recalculation, optimizing the creation process.
Returns:
TensorAwareStateDict: An instance initialized with the provided sharded state dictionary
and optional cached metadata.
- The metadata is stored in memory to speed up future saves.
"""
with debug_time("_get_distribution", logger):
cls._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(
sharded_state_dict, cached_metadata is None
)
cacheable_distribution = cls._get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_metadata
)
if cacheable_distribution is not None:
shard_to_saving_rank, _, _, _ = cacheable_distribution
cls._remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
)
return (
MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part),
cacheable_distribution,
)
@property
def is_hollow(self):
"""
True iff tensors had been extracted and have not been inserted back yet.
"""
return self._is_hollow
@property
def _sharded_tensors(self):
# Three possible states for sharded_tensor:
# 1. sharded_tensor with data (.data = tensor)
# 2. sharded_tensor hollow (.data = None, .orig_device = orig_device)
# 3. removed sharded_tensor (.data = None, no device information)
# TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data
if self.is_hollow:
for sh_base in nested_values(self.sharded_state_dict):
# FIXME: Hacky way to store the original device of the popped tensor
if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'):
yield sh_base
else:
for sh_base in nested_values(self.sharded_state_dict):
if isinstance(sh_base, ShardedTensor) and sh_base.data is not None:
yield sh_base
@property
def tensors(self) -> Iterator[torch.Tensor]:
"""
Get the tensor data from the state dict.
"""
assert not self.is_hollow # TODO raise exception
return map(lambda sh_ten: sh_ten.data, self._sharded_tensors)
@property
def common_state_dict(self) -> Dict:
"""
Get the common state dict from the state dict.
"""
return self.common
def pop_tensors(self) -> List[torch.Tensor]:
"""
Extracts the tensor data from the wrapped state dict, preserving metadata.
Replaces the tensor data in sharded_tensors with the device type of the extracted tensors.
After this operation, the state dictionary is "hollow", containing no tensor data.
Further calls to `pop_tensors` will raise an error.
Returns: List of extracted tensors
"""
assert not self.is_hollow # TODO raise exception
result = []
for sh_ten in self._sharded_tensors:
result.append(sh_ten.data)
# FIXME: Hacky way to store the original device, which is not included in the metadata
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = None
self._is_hollow = True
return result
def insert_tensors(self, tensor_data: Iterable[torch.Tensor]):
"""
Reverse of `pop_tensors`. Replaces the device type in sharded_tensors with the actual tensor values.
Value of `self` is considered to be the same after:
```
self.insert_tensors(self.pop_tensors())
```
"""
assert self.is_hollow # TODO raise exception
for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data):
# FIXME: Hacky way to store the original device
if sh_ten.orig_device == ten.device.type:
delattr(sh_ten, 'orig_device')
# Tensor might be on non-original device
sh_ten.data = ten
self._is_hollow = False
def init_tensors(self):
"""
Initializes empty tensors with the same properties as the original tensors.
This function should only be called after the original tensors have been popped.
It ensures that the newly created empty tensors match the shape,
dtype, and device of the originals, but contain no data.
"""
assert self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# Hacky way to retrieve the original device
sh_ten.init_data(sh_ten.orig_device)
delattr(sh_ten, 'orig_device')
self._is_hollow = False
def copy_tensors_to_cpu(self, non_blocking=False):
"""
Stores CPU copies of tensors in the state_dict, replacing the originals,
but without destroying them.
The original devices are remembered for restoration with restore_tensor_device().
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
if sh_ten.data.device.type == 'cpu':
# Skip cloning if it's already confirmed to be a copy
if not hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.clone()
else:
# FIXME: Hacky way to store the original device
if not hasattr(sh_ten, 'orig_device'):
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking)
def restore_tensor_device(self, non_blocking=True):
"""
Restores all tensors to their original devices, if a move is required.
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# FIXME: Hacky way to store the original device
if hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking)
delattr(sh_ten, 'orig_device')
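# Illustrative sketch (not part of the original class): a typical offload/reload
# cycle using the two methods above; `ta_sd` is assumed to be an
# MCoreTensorAwareStateDict whose tensors live on GPU.
#
#   ta_sd.copy_tensors_to_cpu(non_blocking=True)    # queue D2H copies
#   torch.cuda.synchronize()                        # ensure the CPU copies are complete
#   ...                                             # e.g. persist the CPU copies asynchronously
#   ta_sd.restore_tensor_device(non_blocking=True)  # move tensors back to their original devices
#   torch.cuda.synchronize()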
def _insert_sharded_data(
self, fully_parallel, sharded_part, parallelization_group, exchange_algo
):
loaded_tensors = {}
for sh_ten in self._sharded_tensors:
loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data
if fully_parallel:
with debug_time("_get_distribution", logger):
distribution = self._get_distribution(
fully_parallel, sharded_part, parallelization_group
)
if distribution is not None:
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
# TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
with debug_time("exchange_by_distribution", logger):
loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
distribution,
parallelization_group,
exchange_algo,
)
torch.cuda.synchronize()
loaded_objects = {}
for sh_base in nested_values(self.sharded_state_dict):
if not isinstance(sh_base, ShardedTensor):
assert isinstance(sh_base, ShardedObject)
loaded_objects[_sharded_object_id(sh_base)] = sh_base.data
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys())
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
@debug_time("to_state_dict", logger)
def to_state_dict(
self,
sharded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Convert tensor-aware dict back to the original state_dict
"""
with debug_time("load_preprocess_and_state_dict_manipulations", logger):
assert not self.is_hollow # TODO raise exception
self._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate self.sharded_state_dict and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
with debug_time("validate_sharding_integrity", logger):
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
with debug_time("_insert_sharded_data", logger):
self._insert_sharded_data(
fully_parallel, sharded_part, parallelization_group, exchange_algo
)
with debug_time("apply_factory_merges", logger):
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
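# Illustrative sketch (not part of the original module): the intended round trip
# through the methods above, assuming `sharded_state_dict` is a valid MCore
# sharded state dict.
#
#   ta_sd, distribution = MCoreTensorAwareStateDict.from_state_dict(sharded_state_dict)
#   tensor_data = ta_sd.pop_tensors()      # ta_sd is now "hollow"
#   ...                                    # e.g. hand tensor_data to an async writer
#   ta_sd.insert_tensors(tensor_data)      # ta_sd holds the tensor data again
#   state_dict = ta_sd.to_state_dict(sharded_state_dict)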
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
import logging
from contextlib import contextmanager
from time import time
from typing import Dict, Optional, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
......@@ -20,6 +22,18 @@ from .mapping import (
_ShardId = Tuple[str, tuple, Optional[tuple]]
def zip_strict(*args):
"""
Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).
Apart from providing the functionality on earlier Python versions, it is also more verbose:
Python's builtin zip does not report the lengths, only which iterable finished earlier.
"""
args = [list(a) for a in args]
lens = [len(a) for a in args]
assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
return zip(*args)
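# Illustrative example (not part of the original module): unlike the builtin zip,
# zip_strict reports the offending lengths when the iterables differ.
#
#   list(zip_strict([1, 2], ['a', 'b']))  # -> [(1, 'a'), (2, 'b')]
#   zip_strict([1, 2], ['a'])             # AssertionError: Tried to zip iterables of unequal lengths: [2, 1]!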
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
"""Unique id of the sharded tensor data.
......@@ -217,3 +231,89 @@ def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
fallback_logger = logging.getLogger(__name__)
__LOGGER_NAME_STACK = []
__LOGGER_STACK = []
@contextmanager
def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None):
"""Context manager for managing logger and name stack.
Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical
logging and contextual logger usage. Ensures the logger stack is restored afterward.
Args:
name (str, optional): Name to add to the logger stack. Defaults to None.
current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in
the stack or a fallback if none exist.
Yields:
Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and
the current logger for the block.
Example:
with logger_stack("scope", logger):
logger.info("Log within 'scope'")
"""
if name:
__LOGGER_NAME_STACK.append(name)
if current_logger:
__LOGGER_STACK.append(current_logger)
last_logger = current_logger
elif __LOGGER_STACK:
last_logger = __LOGGER_STACK[-1]
else:
last_logger = fallback_logger
try:
yield ".".join(__LOGGER_NAME_STACK), last_logger
finally:
if name and __LOGGER_NAME_STACK:
__LOGGER_NAME_STACK.pop(-1)
if current_logger and __LOGGER_STACK:
__LOGGER_STACK.pop(-1)
@contextmanager
def debug_time(
name: str, logger: Optional[logging.Logger] = None, threshold: float = float("-inf"), level=None
):
"""Simple context manager for timing functions/code blocks.
Args:
name (str): Label describing the code being measured.
logger (logging.Logger, optional): Logger for output. Defaults to the innermost logger on the stack (or the fallback logger).
threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster.
level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset;
WARNING otherwise.
"""
with logger_stack(name, logger) as (stacked_name, last_logger):
start = time()
try:
yield
finally:
result = time() - start
if result < threshold:
return
if level is None:
level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
last_logger.log(level, f"{stacked_name} took {result:.4f}s")
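# Illustrative example (not part of the original module): nested debug_time blocks
# share the logger stack, so inner blocks log dotted names.
#
#   with debug_time("save", logger):
#       with debug_time("serialize"):   # logged as "save.serialize took ...s"
#           do_serialize()              # `do_serialize` is a placeholder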
def debug_msg(msg: str):
"""Logs a debug message using the current logger stack.
This function formats and logs a debug message with the current logger
and name stack, preserving context from the logger_stack context manager.
Args:
msg (str): The message to be logged at the debug level.
Example:
debug_msg("Checkpoint initialized")
# Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
"""
with logger_stack(None, None) as (stacked_name, last_logger):
last_logger.debug(f"{stacked_name} {msg}")
......@@ -412,7 +412,7 @@ def validate_sharding_integrity(
CheckpointingException for invalid access pattern
"""
if common_state_dict:
if common_state_dict is not None:
_validate_common_state_dict(common_state_dict)
if torch.distributed.get_rank() != 0:
......@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
# For each shard with at least 1 flattened tensor in it, the above
# `_validate_sharding_for_key_flattened` ensures a correct, consistent pattern.
# The only thing that can go wrong at this point is that some shards don't have
# *any* representatives, which is checked below by comparing `shard_access_cnt == 1`.
shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
if not torch.all(shard_access_cnt == 1):
raise CheckpointingException(
f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
)
def _compute_shards_access(rank_sharding):
......@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
expected_size = np.product(local_shape)
if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
raise CheckpointingException(
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
)
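# Illustrative example (not part of the original module), assuming a flattened
# shard with local_shape == (6,), i.e. expected_size == 6:
#   slices (0, 2), (2, 6) -> starts = (0, 2), stops = (2, 6): gap-free full cover, OK
#   slices (0, 2), (3, 6) -> starts[1] != stops[0] (3 != 2): raises CheckpointingException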
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .loss_func import loss_func
from .model_provider import model_provider
from .fully_sharded_data_parallel import FullyShardedDataParallel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from megatron.core import parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
AllGatherPipeline,
BucketingPolicy,
GradReducePipeline,
ParamAndGradBuffer,
PrefetchOrder,
)
from megatron.core.distributed.data_parallel_base import _BaseDataParallel
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import is_submodule, log_single_rank
logger = logging.getLogger(__name__)
class TrainingState(Enum):
"""States of a FSDP parameter group, which are coupled with
the sharding activity of parameters and gradients during training."""
# From pre-forward until post-forward, where parameters should be unsharded
FORWARD = auto()
# Prior to backward computation, where parameters should be unsharded
PRE_BACKWARD = auto()
# After backward computation, where gradients should be re-sharded
POST_BACKWARD = auto()
# Before and after module forward computation, or before the pre-backward and
# after the post-backward states, where no un/sharding activity happens
IDLE = auto()
class FullyShardedDataParallel(_BaseDataParallel):
"""Fully Sharded Data Parallel training for MCore models.
A distributed training wrapper that shards model parameters, gradients and optimizer
states across data parallel workers. Integrates seamlessly with MCore's tensor
and expert parallelism features.
We support the following modes:
- no_shard: Traditional data parallel training without parameter sharding.
- optim: Shards optimizer states (and main weights for mixed-precision training);
this is conceptually close to "ZeRO-1". The `optim_grads` and `optim_grads_params`
modes below also shard main weights during mixed-precision training, even though
their names do not spell this out.
- optim_grads: Shards gradients and optimizer states, this is conceptually close to "ZeRO-2".
- optim_grads_params: Shards parameters, gradients and optimizer states, this
is conceptually close to "ZeRO-3".
Key Features:
- Compatible with MCore's tensor, context and expert parallelism
- Automatic mixed precision training (BF16/FP8)
- Gradient accumulation and bucketing
- Optimized activation recompute with shard-aware communication: When recomputing
a whole Transformer layer, gather parameters once for both the recomputation
and backward computation
- Compatible with MCore's distributed checkpointing
Args:
config: Transformer config object.
ddp_config: FullyShardedDataParallel config object.
module: Underlying model.
fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
i.e., the minimum releasable model unit. If not provided, defaults to
[TransformerLayer, LanguageModelEmbedding] for GPT-like models.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket.
Examples:
>>> model = GPTModel(config)
>>> model = FullyShardedDataParallel(
... config,
... model,
... ddp_config,
... fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding],
... )
"""
# TODO: add hybrid FSDP (shard model states in a partial DP domain)
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
self.bucket_size = self.ddp_config.bucket_size
if disable_bucketing:
self.bucket_size = None
self.device = device if device else torch.cuda.current_device()
self.param_to_bucket_group = {}
if fsdp_unit_modules is not None:
self.fsdp_unit_modules = fsdp_unit_modules
else:
self.fsdp_unit_modules = [TransformerLayer]
if not getattr(self.module, "share_embeddings_and_output_weights", False):
self.fsdp_unit_modules.append(LanguageModelEmbedding)
self.main_weights = True
self.data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
self.expert_data_parallel_group = parallel_state.get_expert_data_parallel_group()
# Determine if we should delay the gradient reduction.
self.is_delay_grad_reduce = self.ddp_config.data_parallel_sharding_strategy in [
"no_shard",
"optim",
]
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
assert self.ddp_config.overlap_param_gather
if not self.is_delay_grad_reduce:
assert self.ddp_config.overlap_grad_reduce
self._init_fsdp_param_and_grad_buffer()
self._register_fsdp_hooks(self.module)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
def _init_fsdp_param_and_grad_buffer(self):
if self.config.calculate_per_token_loss:
# We don't need to scale the gradients in this case.
gradient_scaling_factor = None
expert_gradient_scaling_factor = None
else:
if self.ddp_config.average_in_collective:
# FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR.
raise Exception("Not supported")
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Initialize the param and grad buffer.
self.data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy
self.param_to_name = {p: name for name, p in self.module.named_parameters()}
self.param_and_grad_buffer = ParamAndGradBuffer(
self.ddp_config,
self.module,
bucketing_policy=BucketingPolicy(
suggested_bucket_size=self.bucket_size,
fsdp_unit_modules=(
# Only when model weights need to be sharded, we need to
# identify the minimum releasable model unit, which is the
# FSDP Unit Module.
self.fsdp_unit_modules
if self.data_parallel_sharding_strategy == "optim_grads_params"
else []
),
data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
),
data_parallel_group=self.data_parallel_group,
expert_data_parallel_group=self.expert_data_parallel_group,
preserve_fp32_weights=self.ddp_config.preserve_fp32_weights,
grad_reduce_in_fp32=self.ddp_config.grad_reduce_in_fp32,
gradient_scaling_factor=gradient_scaling_factor,
expert_gradient_scaling_factor=expert_gradient_scaling_factor,
device=self.device,
reset_parameters_for_meta_device_init_module=self.config.init_model_with_meta_device,
)
self.param_and_grad_buffer
self.side_stream_for_buffer_copy_and_grad_accum = torch.cuda.Stream()
# Initialize the reduce-scatter pipeline.
self.grad_reduce_pipeline = GradReducePipeline(
self.param_and_grad_buffer, cuda_stream=self.side_stream_for_buffer_copy_and_grad_accum
)
# Initialize the all-gather pipeline.
self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)
self.suggested_RS_queue_capacity = self.ddp_config.suggested_communication_unit_size
self.suggested_AG_prefetch_size = self.ddp_config.suggested_communication_unit_size
def _register_fsdp_hooks(self, root_module):
"""Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
This function sets up various hooks required for FSDP operations, including parameter
resharding/unsharding and gradient handling. The registered hooks are:
- Pre-forward hook: Unshards parameters before forward pass
- Post-forward hook: Reshards parameters after forward pass
- Pre-backward hook: Unshards parameters before backward pass
- Post-backward hook: Reshards parameters after backward pass
- Gradient accumulation hook: Handles gradient accumulation and reduction across devices
Args:
root_module: The PyTorch module to register FSDP hooks on
Note:
These hooks are essential for FSDP's memory efficiency as they manage:
1. Dynamic parameter sharding/unsharding to reduce memory footprint
2. Proper gradient synchronization across distributed processes
3. Gradient accumulation for large batch training
Returns:
None
"""
# Initialize module training state.
for m in root_module.modules():
setattr(m, "_training_state", TrainingState.IDLE)
self.forward_pre_hooks = {}
self.forward_hooks = {}
self.backward_pre_hooks = {}
"""
An FSDP unit is a module designed to manage the lifecycle of model parameters
in Fully Sharded Data Parallel (FSDP) training. It ensures that parameters
are only used within the module and are released immediately after
the forward and backward computations are completed.
This approach is crucial for efficient memory management, as releasing
parameters too early can lead to issues if other computations depend on them.
`optim` and `optim_grads` do not require FSDP units because they do not
shard model parameters.
"""
if self.data_parallel_sharding_strategy != "optim_grads_params":
fsdp_unit_modules = []
else:
fsdp_unit_modules = self.fsdp_unit_modules
def release_module_parameters(module, *unused):
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.release_bucket(bucket_id)
if not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
release_params_fp8_transpose_cache(module.parameters())
def release_params_fp8_transpose_cache(params):
for param in params:
if is_float8tensor(param):
param._transpose_invalid = True
param._transpose = None
def all_gather_module_parameters(
module,
*unused,
prefetch=True,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
wait_bucket_ready=True,
):
wait_list = []
ag_pipeline = self.all_gather_pipeline
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
ag_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=prefetch,
prefetch_order=prefetch_order,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
wait_list.append(bucket_id)
if wait_bucket_ready:
for bucket_id in wait_list:
ag_pipeline.wait_bucket_ready(bucket_id)
def _post_backward(module, *unused):
release_module_parameters(module)
module._training_state = TrainingState.IDLE
def _pre_forward(module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
input_training_state = module._training_state
fsdp_forward_prefetch = True
if input_training_state == TrainingState.PRE_BACKWARD:
# In activation recomputation case, we need to cancel forward prefetch.
fsdp_forward_prefetch = False
else:
module._training_state = TrainingState.FORWARD
if isinstance(module, tuple(fsdp_unit_modules)):
wait_list = []
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
wait_list.append(bucket_id)
for bucket_id in wait_list:
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
if not torch.is_grad_enabled():
return args, kwargs
# Register the backward function to release the parameters.
args_list, args_spec = tree_flatten(args)
kwargs_list, kwargs_spec = tree_flatten(kwargs)
args_kwargs_list = list(args_list) + list(kwargs_list)
inp_tensor_indices: List[int] = []
inp_tensors: List[torch.Tensor] = []
for i, obj in enumerate(args_kwargs_list):
if torch.is_tensor(obj) and obj.requires_grad:
inp_tensor_indices.append(i)
inp_tensors.append(obj)
if len(inp_tensors) == 0:
return args, kwargs
inp_tensors = RegisterFSDPBackwardFunction.apply(
functools.partial(_post_backward, module), *inp_tensors
)
for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
args_kwargs_list[inp_tensor_idx] = inp_tensor
args_list = args_kwargs_list[: len(args_list)]
kwargs_list = args_kwargs_list[len(args_list) :]
args = tree_unflatten(args_list, args_spec)
kwargs = tree_unflatten(kwargs_list, kwargs_spec)
return args, kwargs
else:
# All-gather the parameters in every forward pass for FSDP.
for param in module.parameters(recurse=False):
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
for param in module.parameters(recurse=False):
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
return args, kwargs
if self.ddp_config.overlap_param_gather:
fsdp_modules = []
for name, module in root_module.named_modules():
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_pre_hooks[f'module {name} parameter all-gather'] = (
module.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True)
)
def _pre_backward(module: nn.Module, *unused):
module._training_state = TrainingState.PRE_BACKWARD
if isinstance(module, tuple(fsdp_unit_modules)):
all_gather_module_parameters(
module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
)
def _root_pre_backward(module: nn.Module, *unused):
"""Marks the module's training state as 'pre_backward' before the
backprop, this function is registered on the root module.
This marking enables us to determine whether forward pass needs to
perform reshard/unshard operations in activation recomputation
scenarios.
"""
for module in root_module.modules():
if isinstance(module, tuple(fsdp_unit_modules)):
module._training_state = TrainingState.PRE_BACKWARD
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
self.all_gather_pipeline.release_bucket(bucket_id)
def _post_forward(module: nn.Module, input: Any, output: Any):
# When composing with module-hook-based activation checkpointing, the
# post-backward hook is responsible for the reshard
if module._training_state == TrainingState.PRE_BACKWARD:
return output
release_module_parameters(module)
module._training_state = TrainingState.IDLE
return output
def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
release_params_fp8_transpose_cache(module.parameters(recurse=False))
if self.data_parallel_sharding_strategy == "optim_grads_params":
fsdp_modules = []
for name, module in root_module.named_modules():
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_hooks[f"release module {name} parameters"] = (
module.register_forward_hook(_post_forward, prepend=False)
)
self.backward_pre_hooks[f"all-gather module {name} parameters"] = (
module.register_full_backward_pre_hook(_pre_backward)
)
elif not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
self.forward_hooks[f"remove module {name} fp8 transpose cache"] = (
module.register_forward_hook(
_release_module_fp8_transpose_cache, prepend=False
)
)
self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
_root_pre_backward
)
def _make_param_hook(param: torch.nn.Parameter):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
wait_previous_grad_reduce = not self.is_delay_grad_reduce
# FIXME: Use insert forward op to replace grad acc hook, which will
# be lost after parameter data movement. For example, module.cuda()
# will cause the registered grad acc hook to be lost.
def param_hook(*unused):
if param.requires_grad:
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
if self.is_delay_grad_reduce:
param.main_grad.add_(param.grad.data)
else:
param.main_grad.copy_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce and (
not self.is_delay_grad_reduce or self.is_last_microbatch
):
gr_pipeline = self.grad_reduce_pipeline
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
gr_pipeline.place_bucket(bucket_id)
go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
if go_rs and wait_previous_grad_reduce:
gr_pipeline.wait_for_previous_grad_reduce(
recommeded_queue_capacity=self.suggested_RS_queue_capacity
)
return param_hook
# Register backward gradient accumulation hook for each parameter.
self.grad_accs = []
for param in root_module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
if param.requires_grad:
if wbuf and wbuf.is_data_distributed:
wbuf.fetch_bucket(and_allocate_params_data=True)
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(_make_param_hook(param))
self.grad_accs.append(grad_acc)
if wbuf and wbuf.is_data_distributed:
wbuf.free_bucket_storage()
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
In gradient-sharding modes, gradient synchronization actually always happens.
"""
# FIXME: Better handling of grads shard mode and no_sync in the training loop so that
# the code doesn't bog down developers.
self.is_last_microbatch = False
try:
yield
finally:
self.is_last_microbatch = True
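# Illustrative sketch (not part of the original class): a typical gradient
# accumulation loop, assuming `model` wraps this FullyShardedDataParallel
# instance, `microbatches` is a list of inputs, and `contextlib` is imported.
#
#   for i, batch in enumerate(microbatches):
#       sync_ctx = contextlib.nullcontext() if i + 1 == len(microbatches) else model.no_sync()
#       with sync_ctx:
#           loss = compute_loss(model, batch)   # `compute_loss` is a placeholder
#           loss.backward()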
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync and self.ddp_config.overlap_param_gather:
# All-gather the first bucket before the forward pass.
self.all_gather_pipeline.queue_bucket_to_all_gather(bucket_id=0, prefetch=False)
else:
self.all_gather_pipeline.reset()
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.all_gather_bucket_and_set_items(
bucket_id=bucket_id, async_op=True
)
group = self.param_and_grad_buffer.parameter_groups[bucket_id]
if group.model_weight_buffer is None:
continue
if group.model_weight_buffer.is_data_distributed:
# If model weight is sharded, we wait for the all-gather to complete and
# then release the bucket immediately to save memory usage.
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if not self.ddp_config.overlap_grad_reduce:
if self.data_parallel_sharding_strategy == "no_shard":
self.param_and_grad_buffer.all_reduce_gradients(
async_op=self.ddp_config.overlap_grad_reduce
)
else:
self.param_and_grad_buffer.reduce_scatter_gradients()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if self.ddp_config.overlap_grad_reduce:
self.grad_reduce_pipeline.wait_for_previous_grad_reduce(0)
self.grad_reduce_pipeline.reset()
else:
self.start_grad_sync()
self.param_and_grad_buffer.update_main_grads()
if self.ddp_config.overlap_param_gather:
self.all_gather_pipeline.reset()
def optimizer_named_parameters(self) -> List[Tuple[str, torch.Tensor]]:
"""
Returns a list of tuples containing the main weights and their corresponding names
for mixed-precision training, to be used by the optimizer for updates.
Returns:
List[Tuple[str, torch.Tensor]]: A list of tuples, where each tuple
contains a main weight tensor and its corresponding name.
"""
return self.param_and_grad_buffer.optimizer_named_parameters
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
self.param_and_grad_buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
self.param_and_grad_buffer.zero_grad()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group(
with_context_parallel=True
)
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
# make a copy of the state_dict to avoid modifying the input state_dict
state_dict = state_dict.copy()
state_dict_extra_states = {}
for key in list(state_dict.keys()):
if key.endswith("_extra_state"):
state_dict_extra_states[key] = state_dict[key]
del state_dict[key]
self.module.load_state_dict(state_dict_extra_states, strict=False)
prefix = "module."
buffer = self.param_and_grad_buffer
for param_groups in buffer.parameter_groups:
wbuf = param_groups.model_weight_buffer
for model_param in wbuf.params:
if is_float8tensor(model_param):
fp8_meta = model_param._fp8_meta['scaling_fwd']
fp8_meta_index = model_param._fp8_meta_index
model_param._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
param_name = f"{buffer.param_to_name[model_param]}"[len(prefix) :]
if param_name in state_dict:
if wbuf and wbuf.is_data_distributed:
model_param.fully_shard_param_local_shard.data.copy_(
state_dict[param_name]
)
else:
model_param.data.copy_(state_dict[param_name])
del state_dict[param_name]
self.module.load_state_dict(state_dict, strict=False)
return
self.module.load_state_dict(state_dict, strict=strict)
class RegisterFSDPBackwardFunction(torch.autograd.Function):
"""
Register a backward function that will be called after the backward pass
of the model. This function is used to release the parameters after the
backward pass.
"""
@staticmethod
def forward(ctx, post_backward, *inputs: torch.Tensor):
"""
Forward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward = post_backward
return inputs
@staticmethod
def backward(ctx, *grads: torch.Tensor):
"""
Backward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward()
return (None,) + grads
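# Illustrative sketch (not part of the original module): how the autograd function
# above is used in `_pre_forward` to trigger a reshard callback once gradients for
# the module's inputs have been produced.
#
#   new_inputs = RegisterFSDPBackwardFunction.apply(
#       functools.partial(_post_backward, module),  # invoked from backward()
#       *input_tensors,                             # input tensors that require grad
#   )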
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import gc
import inspect
import logging
import math
import traceback
import warnings
from collections import namedtuple
from contextlib import ExitStack
from enum import Enum
from typing import Any, List, Optional, Tuple
import torch
from megatron.core import parallel_state
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.fp8_utils import is_float8tensor, quantize_param_fragment
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.utils import is_submodule, is_te_min_version, log_on_each_pipeline_stage
try:
from transformer_engine.pytorch import fp8_model_init
except ImportError:
pass
try:
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
except ImportError:
pass
logger = logging.getLogger(__name__)
def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
"""Alternate to ``assert`` when in the backward context to print the error
message ``s`` since otherwise, it is swallowed.
"""
if not cond:
print(s)
traceback.print_stack()
if raise_assertion_error:
raise AssertionError(s)
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
"""
Allocate storage for ``tensor`` with the given size.
This is a no-op if the storage is already allocated to ``size.numel()`` elements.
"""
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_allocated = tensor._typed_storage()._size() == size.numel()
if not already_allocated:
tensor_storage_size = tensor._typed_storage()._size()
_p_assert(
tensor_storage_size == 0,
"Tensor storage should have been resized to be 0 but got PLACEHOLDEr",
)
tensor._typed_storage()._resize_(size.numel())
def _free_storage(tensor: torch.Tensor):
"""
Frees the underlying storage of ``tensor``.
This is a no-op if the storage is already freed.
"""
with torch.no_grad():
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
already_freed = tensor._typed_storage()._size() == 0
if not already_freed:
_p_assert(
tensor.storage_offset() == 0,
"Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
f"storage offset: {tensor.storage_offset()}\n"
f"storage size: {tensor._typed_storage()._size()}\n"
f"tensor shape: {tensor.shape}",
)
tensor._typed_storage()._resize_(0)
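# Illustrative example (not part of the original module): the two helpers above
# resize a tensor's underlying storage without touching its metadata.
#
#   t = torch.empty(1024, device='cuda')
#   _free_storage(t)                        # storage shrinks to 0 elements, t.shape stays (1024,)
#   _alloc_storage(t, torch.Size([1024]))   # storage is re-allocated for 1024 elements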
TensorItemIndex = namedtuple(
'TensorItemIndex', ['global_data_index', 'size', 'item_id', 'bucket_id', 'shape']
)
BucketIndex = namedtuple('BucketIndex', ['bucket_id', 'global_data_index', 'size', 'items'])
ShardBucketIndex = namedtuple(
'ShardBucketIndex',
['bucket_id', 'global_data_index', 'local_data_index', 'bucket_data_index', 'size'],
)
@dataclasses.dataclass
class BucketingPolicy:
"""
A policy for bucketing in Fully Sharded Data Parallel (FSDP) training.
Attributes:
suggested_bucket_size (int): The suggested size of each bucket, in number of elements.
fsdp_unit_modules (list): A list of module classes that are treated as a
single unit for FSDP bucketing.
data_parallel_sharding_strategy (str): The strategy used for sharding
data parallel modules.
Note:
This policy is used to configure the bucketing behavior in FSDP training.
"""
suggested_bucket_size: Optional[int] = 40_000_000
fsdp_unit_modules: List[torch.nn.Module] = dataclasses.field(default_factory=list)
data_parallel_sharding_strategy: str = 'no_shard'
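# Illustrative example (not part of the original module): a policy for the
# 'optim_grads_params' strategy with per-layer FSDP units (TransformerLayer
# import assumed).
#
#   policy = BucketingPolicy(
#       suggested_bucket_size=40_000_000,
#       fsdp_unit_modules=[TransformerLayer],
#       data_parallel_sharding_strategy='optim_grads_params',
#   )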
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def build_data_parallel_buffer_index(
elements: List[torch.Size],
data_parallel_rank: int,
data_parallel_world_size: int,
is_data_distributed: bool,
ddp_config: DistributedDataParallelConfig,
bucket_id: int = 0,
) -> Tuple[int, List[tuple], List[tuple], List[tuple]]:
"""
Assuming that all input tensors are laid out consecutively in a global
buffer, compute the index range of every tensor, of the bucket, and of the
local (per-rank) shard of the bucket.
Args:
elements (List[torch.Size]): Shapes of the input tensors.
data_parallel_rank (int): Rank of the current process in the data parallel group.
data_parallel_world_size (int): World size of the data parallel group.
bucket_id (int, optional): The id of the bucket. Defaults to 0.
Returns:
Tuple[List[TensorItemIndex], BucketIndex, ShardBucketIndex]: The index of every tensor
item, the bucket index, and the index of the local (per-rank) shard of the bucket.
"""
def _pad_if_needed(data_index: int) -> int:
"""
Pads data indices if using distributed optimizer (to ensure uniform sharding).
"""
if ddp_config.data_parallel_sharding_strategy != 'no_shard':
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(data_index, math.lcm(data_parallel_world_size, 128))
return data_index
def add_item(item_id, item, bucket, item_index_map, bucket_id):
bucket.append(item)
bucket_size = sum([it.numel() for it in bucket])
item_index_map.append(
TensorItemIndex(
data_index + bucket_size - item.numel(),
item.numel(),
item_id=item_id,
bucket_id=bucket_id,
shape=item,
)
)
item_index_map = []
bucket = []
data_index = 0
for item_id, item in enumerate(elements):
add_item(item_id, item, bucket, item_index_map, bucket_id)
bucket_size = sum([it.numel() for it in bucket])
bucket_size = _pad_if_needed(bucket_size)
bucket_index = BucketIndex(
bucket_id,
data_index,
bucket_size,
items=list(filter(lambda x: x.bucket_id == bucket_id, item_index_map)),
)
shard_size = bucket_index.size // data_parallel_world_size
bucket_data_index = shard_size * data_parallel_rank
global_data_index = bucket_index.global_data_index + bucket_data_index
if is_data_distributed:
shard_bucket_index = ShardBucketIndex(
bucket_id, global_data_index, 0, bucket_data_index, shard_size
)
else:
shard_bucket_index = ShardBucketIndex(
bucket_id, global_data_index, global_data_index, bucket_data_index, shard_size
)
return item_index_map, bucket_index, shard_bucket_index
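# Illustrative worked example (not part of the original module), assuming
# dp_world_size=2, elements=[torch.Size([4]), torch.Size([8])], a sharding
# strategy other than 'no_shard', and is_data_distributed=True:
#   - the items occupy global indices [0, 4) and [4, 12)
#   - the bucket is padded to lcm(2, 128) = 128 elements
#   - each rank owns a 64-element shard: rank 0 covers [0, 64), rank 1 covers [64, 128)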
@dataclasses.dataclass
class Bucket:
"""
A container for holding data in Fully Sharded Data Parallel (FSDP) training.
Attributes:
data (torch.Tensor): A tensor containing the data elements
grouped together in a bucket.
data_operation_event (Optional[torch.cuda.Event]): An optional CUDA event
used to synchronize data operations.
status (Any): An optional status object used to track the state of the bucket.
Note:
Buckets are used to optimize communication in FSDP training by
grouping small tensors together.
"""
data: torch.Tensor
data_operation_event: Optional[torch.cuda.Event] = None
status: Any = None
class TemporaryBucketAllocator:
"""
A utility class for managing temporary buckets (buffers) used in FSDP
operations like parameter unsharding and gradient reduction.
This allocator handles the dynamic allocation and deallocation of temporary memory buffers
needed during FSDP (Fully Sharded Data Parallel) operations, particularly for parameter
unsharding and gradient reduction. It helps optimize memory usage by allowing temporary
buckets to be released when no longer needed.
Key Features:
- Dynamic allocation of temporary buckets for FSDP operations
- Memory-efficient management of temporary buffers
- Support for both parameter unsharding and gradient reduction operations
- Automatic cleanup of unused buckets to save memory
Usage:
```python
# Create an allocator instance
allocator = TemporaryBucketAllocator()
# Allocate a temporary bucket
temp_bucket = allocator.allocate(bucket_id=0, size=1024, dtype=torch.float32, device=torch.device("cuda"))
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done
allocator.free(bucket_id=0)
```
Note:
It's important to release temporary buckets after use to prevent memory leaks
and optimize memory usage during training.
"""
def __init__(self):
self.buckets = {}
def allocate(
self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device
) -> Bucket:
"""
allocate a temporary bucket.
"""
if bucket_id not in self.buckets:
self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device))
return self.buckets[bucket_id]
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.buckets:
_free_storage(self.buckets[bucket_id].data)
del self.buckets[bucket_id]
class StorageResizeBasedBucketAllocator(TemporaryBucketAllocator):
"""
A specialized temporary bucket allocator that resizes the storage of temporary buckets
based on the required size.
"""
def __init__(self):
self.buckets = {} # {bucket_id: Bucket}
def allocate(
self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device
) -> Bucket:
"""
allocate a temporary bucket.
"""
if bucket_id not in self.buckets:
self.buckets[bucket_id] = Bucket(data=torch.empty(size, dtype=dtype, device=device))
bucket = self.buckets[bucket_id]
_alloc_storage(bucket.data, torch.Size([size]))
return bucket
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.buckets:
_free_storage(self.buckets[bucket_id].data)
class RotaryBucketAllocator(TemporaryBucketAllocator):
"""A specialized temporary bucket allocator that implements a circular buffer recycling strategy
to minimize memory fragmentation in FSDP operations.
RotaryBucketAllocator extends TemporaryBucketAllocator by maintaining a limited pool of
pre-allocated buffers that are reused in a circular manner. This approach helps prevent
memory fragmentation that typically occurs with frequent allocation and deallocation of
temporary buffers during FSDP operations.
Key Features:
- Circular buffer recycling strategy for memory efficiency
- Reduced memory fragmentation compared to dynamic allocation
- Pre-allocated buffer pool for faster access
- Automatic buffer reuse without explicit deallocation
Usage:
```python
# Create a rotary allocator
allocator = RotaryBucketAllocator(name="gpt_parameters")
# Get a temporary buffer from the pool
temp_bucket = allocator.allocate(bucket_id=0, size=1024, dtype=torch.float32, device=torch.device("cuda"))
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done, returning its buffer to the idle pool
allocator.free(bucket_id=0)
```
"""
def __init__(self, name: str):
self.name = name
self.num_global_buffer = 0
self.idle_buffer = [] # [buffer_id]
self.using_buffer = {} # {bucket_id: buffer_id}
self.buckets = {}
def allocate(
self, bucket_id: int, size: int, dtype: torch.dtype, device: torch.device
) -> Bucket:
"""
allocate a temporary bucket.
"""
def _get_global_buffer(buffer_id: int):
return parallel_state.get_global_memory_buffer().get_tensor(
[size], dtype=dtype, name=self._get_gbuf_name(buffer_id)
)
if bucket_id in self.using_buffer:
buffer_id = self.using_buffer[bucket_id]
return Bucket(data=_get_global_buffer(buffer_id))
if len(self.idle_buffer) == 0:
# allocate new buffer
buffer_id = self.num_global_buffer
self.num_global_buffer += 1
self.idle_buffer.append(buffer_id)
buffer_id = self.idle_buffer.pop(0)
self.using_buffer[bucket_id] = buffer_id
return Bucket(data=_get_global_buffer(buffer_id))
def _get_gbuf_name(self, buffer_id: int):
return f"{self.name}_{buffer_id}"
def free(self, bucket_id: int):
"""
free a temporary bucket.
"""
if bucket_id in self.using_buffer:
buffer_id = self.using_buffer.pop(bucket_id)
self.idle_buffer.append(buffer_id)
class DataParallelBuffer:
"""
A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
params: List[torch.nn.Parameter],
is_data_distributed: bool,
bucket_id: int,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
temporary_bucket_allocator: Optional[TemporaryBucketAllocator] = None,
init_meta_only: bool = False,
is_dtype_float8: bool = False,
gradient_scaling_factor: Optional[float] = None,
) -> None:
self.ddp_config = ddp_config
self.params = params
_param_dtype = {p.dtype for p in self.params}
assert len(_param_dtype) == 1, f'params have different dtypes: {_param_dtype}'
self.is_data_distributed = is_data_distributed
self.bucket_id = bucket_id
self.dtype = dtype if dtype else next(iter(_param_dtype))
self.device = device
self.data_parallel_group = data_parallel_group
self.dp_rank = torch.distributed.get_rank(group=self.data_parallel_group)
self.dp_world_size = torch.distributed.get_world_size(group=self.data_parallel_group)
self.temporary_bucket_allocator = (
temporary_bucket_allocator if temporary_bucket_allocator else TemporaryBucketAllocator()
)
self.is_dtype_float8 = is_dtype_float8
self.gradient_scaling_factor = gradient_scaling_factor
(self.item_index_map, self.bucket_index, self.shard_bucket_index) = (
build_data_parallel_buffer_index(
[p.shape for p in self.params],
self.dp_rank,
self.dp_world_size,
is_data_distributed,
ddp_config,
bucket_id=bucket_id,
)
)
self.data_size = (
self.bucket_index.size if not is_data_distributed else self.shard_bucket_index.size
)
if init_meta_only:
self.data = None
else:
self.data = torch.empty(self.data_size, dtype=self.dtype, device=device)
self.param_idx = {p: i for i, p in enumerate(self.params)}
self.placeholder_bucket = None
self.placeholder_items = {}
def fetch_bucket(
self, dtype: Optional[torch.dtype] = None, and_allocate_params_data: bool = False
) -> Bucket:
"""
Fetch a communication buffer for data-parallel operations.
The size of the bucket is defined by the `DataParallelBuffer` instance.
If `and_allocate_params_data` is True, this method resets the parameter
data stored in the `DataParallelBuffer` instance.
Args:
dtype (Optional[torch.dtype], optional): The data type of the tensor
to fetch a buffer for. Defaults to None.
and_allocate_params_data (bool, optional): Whether to allocate and
reset parameter data. Defaults to False.
Returns:
Bucket: The communication buffer for the specified data type.
"""
if dtype is None:
dtype = self.dtype
bucket_index = self.bucket_index
if not self.is_data_distributed and dtype == self.dtype:
bucket = Bucket(
data=self.data[
bucket_index.global_data_index : bucket_index.global_data_index
+ bucket_index.size
]
)
else:
bucket = self.temporary_bucket_allocator.allocate(
bucket_id=bucket_index.bucket_id,
size=bucket_index.size,
dtype=dtype,
device=self.device,
)
if and_allocate_params_data:
for p in self.params:
item_id = self.param_idx[p]
if is_float8tensor(p):
p._data = self.get_item_from_bucket(bucket, item_id).view(p.shape)
else:
p.data = self.get_item_from_bucket(bucket, item_id).view(p.shape)
return bucket
def free_bucket_storage(self, and_free_params_data: bool = False):
"""
Release the storage of a temporary communication bucket.
If the bucket is temporary, this method frees its storage.
If `and_free_params_data` is True, this method also releases the storage
of the parameter data stored in the `DataParallelBuffer` instance.
Args:
and_free_params_data (bool, optional): Whether to also release the
storage of the parameter data. Defaults to False.
Returns:
None
"""
if not self.is_data_distributed:
return
self.temporary_bucket_allocator.free(self.bucket_index.bucket_id)
if and_free_params_data:
if self.placeholder_bucket is None:
self.placeholder_bucket = Bucket(
data=torch.empty(self.bucket_index.size, dtype=self.dtype, device=self.device)
)
for p in self.params:
item_id = self.param_idx[p]
self.placeholder_items[item_id] = self.get_item_from_bucket(
self.placeholder_bucket, item_id
).view(p.shape)
_free_storage(self.placeholder_bucket.data)
for p in self.params:
item_id = self.param_idx[p]
if is_float8tensor(p):
p._data = self.placeholder_items[item_id]
else:
p.data = self.placeholder_items[item_id]
def _get_item_slice_in_shard(self, item_id: int) -> Tuple[int, int]:
item_index = self.item_index_map[item_id]
shard_bucket_index = self.shard_bucket_index
item_global_start = item_index.global_data_index
item_global_end = item_index.global_data_index + item_index.size
shard_bucket_start = shard_bucket_index.global_data_index
shard_bucket_end = shard_bucket_index.global_data_index + shard_bucket_index.size
if item_global_start > shard_bucket_end or item_global_end < shard_bucket_start:
return (0, 0)
start = max(item_global_start, shard_bucket_start) - item_global_start
end = min(item_global_end, shard_bucket_end) - item_global_start
return (start, end)
# pylint: disable=missing-function-docstring
def locate_item_in_global_item(self, item_id: int) -> Tuple[int, int]:
item_index = self.item_index_map[item_id]
if not self.is_data_distributed:
return (0, item_index.size)
slice_start, slice_end = self._get_item_local_shard_index(item_id)
if slice_start == slice_end:
return (0, 0)
local_shard_index_to_global_index_offset = (
self.shard_bucket_index.global_data_index - self.shard_bucket_index.local_data_index
)
slice_start += local_shard_index_to_global_index_offset
slice_end += local_shard_index_to_global_index_offset
return (
slice_start - item_index.global_data_index,
slice_end - item_index.global_data_index,
)
def _get_item_local_shard_index(self, item_id: int) -> Tuple[int, int]:
slice_start, slice_end = self._get_item_slice_in_shard(item_id)
if slice_start == slice_end:
return (0, 0)
item_index = self.item_index_map[item_id]
shard_bucket_index = self.shard_bucket_index
offset = (
item_index.global_data_index
- shard_bucket_index.global_data_index
+ shard_bucket_index.local_data_index
)
return (offset + slice_start, offset + slice_end)
def _get_item_local_index(self, item_id: int) -> Tuple[int, int]:
if not self.is_data_distributed:
item_index = self.item_index_map[item_id]
return (item_index.global_data_index, item_index.global_data_index + item_index.size)
return self._get_item_local_shard_index(item_id)
def set_item(self, item_id: int, item_data: torch.Tensor) -> None:
"""
Update a tensor item managed by the `DataParallelBuffer` instance.
The storage of the item is mapped to the communication bucket.
This method updates the item data and ensures consistency with the bucket.
Args:
item_id (int): The ID of the tensor item to update.
item_data (torch.Tensor): The new data for the tensor item.
Returns:
None
"""
if self.is_data_distributed:
slice_start, slice_end = self._get_item_slice_in_shard(item_id)
item_data = item_data.flatten()[slice_start:slice_end]
local_index_start, local_index_end = self._get_item_local_index(item_id)
shard = self.data[local_index_start:local_index_end]
if shard.numel() > 0:
shard.data.copy_(item_data.flatten())
def get_item(self, item_id: int, only_shard: bool = False) -> torch.Tensor:
"""
Retrieve a tensor item managed by the `DataParallelBuffer` instance.
The storage of the item is mapped to the communication bucket.
If `only_shard` is True, returns only the shard of the item corresponding
to the current process.
Otherwise, returns the entire item.
Args:
item_id (int): The ID of the tensor item to retrieve.
only_shard (bool, optional): Whether to return only the shard of the
item. Defaults to False.
Returns:
torch.Tensor: The retrieved tensor item.
"""
if only_shard:
start, end = self._get_item_local_shard_index(item_id)
else:
start, end = self._get_item_local_index(item_id)
return self.data[start:end]
def get_item_from_bucket(self, bucket: Bucket, item_id: int):
"""get item from bucket."""
item_index = self.item_index_map[item_id]
bucket_index = self.bucket_index
start_index = item_index.global_data_index - bucket_index.global_data_index
end_index = start_index + item_index.size
item = bucket.data[start_index:end_index]
return item
def get_shard_from_bucket(self, bucket: Bucket):
"""Get the local sharding of the bucket."""
shard_bucket_index = self.shard_bucket_index
offset = shard_bucket_index.bucket_data_index
shard_size = shard_bucket_index.size
shard = bucket.data[offset : offset + shard_size]
return shard
def get_shard_from_local_buffer(self) -> torch.Tensor:
"""Get the local sharding of the bucket."""
index = self.shard_bucket_index
return self.data[index.local_data_index : index.local_data_index + index.size]
@dataclasses.dataclass
class ParameterGroup:
"""
A group of model parameters with associated metadata for data-parallel training.
This dataclass encapsulates a list of PyTorch parameters and additional information
necessary for managing data-parallel operations, such as data type, gradient requirements,
and buffer assignments.
"""
params: List[torch.nn.Parameter]
dtype: Optional[torch.dtype] = None
is_expert_param: bool = False
requires_grad: Optional[bool] = None
fsdp_unit_id: Optional[int] = None
data_parallel_world_size: Optional[int] = None
model_weight_buffer: Optional[DataParallelBuffer] = None
main_weight_buffer: Optional[DataParallelBuffer] = None
main_grad_buffer: Optional[DataParallelBuffer] = None
def _get_parameter_groups(
module: torch.nn.Module, policy: BucketingPolicy, meta_device_init_fp8_params: dict
):
"""
Group the module's parameters into bucketed parameter groups according to their attributes and the bucketing policy.
"""
param_to_name = {p: name for name, p in module.named_parameters()}
fsdp_units = []
if policy.fsdp_unit_modules:
param_to_id = {}
for i, p in enumerate(module.parameters()):
param_to_id[p] = i
fsdp_modules = []
for m in module.modules():
# Skip nested FSDP module.
if any(is_submodule(m, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(m, tuple(policy.fsdp_unit_modules)):
fsdp_units.append([param_to_name[p] for p in m.parameters()])
fsdp_modules.append(m)
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into a separate bucket if using a
distributed optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stages partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to run before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and policy.data_parallel_sharding_strategy != "no_shard"
)
is_expert_parameter = lambda p: not getattr(p, 'allreduce', True)
# Step 1: Group the parameters according to their execution order and attributes.
parameter_groups = []
for name, param in module.named_parameters():
param_attrs = dict(
dtype=(
"float8"
if is_float8tensor(param) or meta_device_init_fp8_params.get(name, False)
else param.dtype
),
is_expert_param=is_expert_parameter(param),
requires_grad=param.requires_grad,
fsdp_unit_id=None,
)
for fsdp_unit_id, fsdp_unit in enumerate(fsdp_units):
if name in fsdp_unit:
param_attrs["fsdp_unit_id"] = fsdp_unit_id
break
found_group = False
for param_group in parameter_groups:
group_attrs = {
key: value for key, value in param_group.__dict__.items() if key in param_attrs
}
if group_attrs == param_attrs:
param_group.params.append(param)
found_group = True
break
if not found_group:
parameter_groups.append(ParameterGroup([param], **param_attrs))
# Step 2: Bucket the parameters based on the guide bucket size.
suggested_bucket_size = policy.suggested_bucket_size
bucket_groups = []
for group in parameter_groups:
bucket = []
basic_attrs = {
key: value
for key, value in group.__dict__.items()
if key in ['dtype', 'is_expert_param', 'requires_grad', 'fsdp_unit_id']
}
for param in group.params:
if _does_param_require_new_bucket(param):
if len(bucket) > 0:
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
bucket_groups.append(ParameterGroup([param], **basic_attrs))
bucket = []
continue
bucket.append(param)
if (
group.fsdp_unit_id is None
and suggested_bucket_size
and sum([p.numel() for p in bucket]) >= suggested_bucket_size
):
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
bucket = []
continue
if bucket:
bucket_groups.append(ParameterGroup(bucket, **basic_attrs))
param_to_param_group = {}
for group_id, group in enumerate(bucket_groups):
for param in group.params:
param_to_param_group[param] = group_id
# Log buckets for all PP stages.
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
):
log_strs = []
log_strs.append(f'Number of parameter groups for FSDP: {len(bucket_groups)}')
for index, group in enumerate(bucket_groups):
numel = 0
for param in group.params:
numel += param.numel()
log_strs.append(
f"Params for group {index+1} ({numel} elements, dtype {group.dtype}, "
f"has_weight_buffer: {group.model_weight_buffer is not None}, "
f"has_grad_buffer: {group.main_grad_buffer is not None}, "
f"has_main_weight_buffer: {group.main_weight_buffer is not None}):"
)
for param in group.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
return (bucket_groups, fsdp_units, param_to_param_group)
class ParamAndGradBuffer:
"""A class that manages parameter grouping, buffer allocation, and
communication operations for data-parallel distributed training.
This class provides functionality to:
1. Group parameters based on their data types and communication group sizes
2. Create contiguous buffers for model weights, gradients, and high-precision
main weights
3. Handle parameter unsharding, gradient reduction, and weight
synchronization operations
Key Features:
- Efficient parameter grouping based on data types and communication patterns
- Memory-efficient contiguous buffer allocation
- Support for mixed-precision training with main weights
- Distributed operations including parameters all-gather and gradients
reduce-scatter/all-reduce
- Synchronized weight updates between model and main weights
Note:
This class is designed for distributed training scenarios where efficient
parameter management and communication are crucial for performance.
Args:
ddp_config (DistributedDataParallelConfig): The distributed data parallel
configuration.
module (torch.nn.Module): The module whose parameters are to be grouped
and flattened.
bucketing_policy (BucketingPolicy): The bucketing policy.
data_parallel_group (torch.distributed.ProcessGroup): The data parallel group.
expert_data_parallel_group (Optional[torch.distributed.ProcessGroup]):
The expert data parallel group.
preserve_fp32_weights (bool): Whether to preserve FP32 weights.
grad_reduce_in_fp32 (bool): Whether to reduce gradients in FP32.
gradient_scaling_factor (Optional[float]): The gradient scaling factor.
expert_gradient_scaling_factor (Optional[float]): The expert gradient
scaling factor.
device (torch.device): The parameter and gradient buffer device.
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad (bool):
Whether to only create the gradient buffer and main weight buffer
for parameters that require gradients. Default is True.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
bucketing_policy: BucketingPolicy,
data_parallel_group: torch.distributed.ProcessGroup,
expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
preserve_fp32_weights: bool = True,
grad_reduce_in_fp32: bool = True,
gradient_scaling_factor: Optional[float] = None,
expert_gradient_scaling_factor: Optional[float] = None,
device: torch.device = torch.device('cuda'),
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad: bool = True,
reset_parameters_for_meta_device_init_module: bool = False,
):
self.ddp_config = ddp_config
self.module = module
self.bucketing_policy = bucketing_policy
self.param_to_name = {p: name for name, p in self.module.named_parameters()}
self.preserve_fp32_weights = preserve_fp32_weights
self.grad_reduce_in_fp32 = grad_reduce_in_fp32
self.data_parallel_group = data_parallel_group
self.expert_data_parallel_group = expert_data_parallel_group
self.params = list(module.parameters())
self.gradient_scaling_factor = gradient_scaling_factor
self.expert_gradient_scaling_factor = expert_gradient_scaling_factor
self.device = device
self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad = (
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
)
self.reset_parameters_for_meta_device_init_module = (
reset_parameters_for_meta_device_init_module
)
# Mark fp8 param.
meta_device_init_fp8_params = {}
if reset_parameters_for_meta_device_init_module:
for m in module.modules():
if not isinstance(m, TransformerEngineBaseModule):
continue
for name, param in m.named_parameters(recurse=False):
# The fp8 param initialized from the meta device may NOT be
# an fp8 tensor, according to the internal logic of the TE
# to determine whether this parameter is fp8 or not.
fp8_meta_index = m.param_init_meta[name].fp8_meta_index
if m.primary_weights_in_fp8 and fp8_meta_index is not None:
meta_device_init_fp8_params[self.param_to_name[param]] = True
# Get the parameter groups.
(self.parameter_groups, self.fsdp_units, self.param_to_param_group) = _get_parameter_groups(
module, bucketing_policy, meta_device_init_fp8_params
)
self._init_each_parameter_group_buffers(meta_device_init_fp8_params)
# Initialize the optimizer named parameters.
self.optimizer_named_parameters = self._init_optimizer_named_parameters()
def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
"""
Initialize the buffers for each parameter group.
"""
data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy
if data_parallel_sharding_strategy == 'no_shard':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = False
is_grad_buffer_distributed = False
elif data_parallel_sharding_strategy == 'optim':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = False
elif data_parallel_sharding_strategy == 'optim_grads':
is_model_weight_buffer_distributed = False
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = True
elif data_parallel_sharding_strategy == 'optim_grads_params':
is_model_weight_buffer_distributed = True
is_main_weight_buffer_distributed = True
is_grad_buffer_distributed = True
else:
raise ValueError(
f'Invalid data_parallel_sharding_strategy: {data_parallel_sharding_strategy}'
)
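# Summary of the flags chosen above (sharding only takes effect when the group's
# data-parallel world size is > 1; `no_shard` additionally skips creating the
# model weight buffer below):
#   strategy             model weights   fp32 main weights   main grads
#   no_shard             replicated      replicated          replicated
#   optim                replicated      sharded             replicated
#   optim_grads          replicated      sharded             sharded
#   optim_grads_params   sharded         sharded             sharded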
self.memory_allocator_for_model_weight_buffer = StorageResizeBasedBucketAllocator()
self.buffer_all_in_one = True
preserve_fp32_weights = self.preserve_fp32_weights
grad_reduce_in_fp32 = self.grad_reduce_in_fp32
buffer_size = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0}
for group_id, group in enumerate(self.parameter_groups):
dp_group = (
self.data_parallel_group
if not group.is_expert_param
else self.expert_data_parallel_group
)
group.data_parallel_world_size = torch.distributed.get_world_size(group=dp_group)
gradient_scaling_factor = (
self.gradient_scaling_factor
if not group.is_expert_param
else self.expert_gradient_scaling_factor
)
one_param = group.params[0]
is_dtype_float8 = is_float8tensor(one_param) or meta_device_init_fp8_params.get(
self.param_to_name[one_param], False
)
if is_dtype_float8:
param_dtype = torch.uint8
grad_dtype = torch.bfloat16
else:
param_dtype = group.params[0].dtype
grad_dtype = param_dtype
should_create_grad_buffer_or_main_weight_buffer = (
not self.only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
or group.requires_grad
)
# Initialize the model weight buffer.
if data_parallel_sharding_strategy != 'no_shard':
group.model_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_model_weight_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=param_dtype,
device=self.device,
data_parallel_group=dp_group,
init_meta_only=True,
is_dtype_float8=is_dtype_float8,
temporary_bucket_allocator=self.memory_allocator_for_model_weight_buffer,
bucket_id=group_id,
)
# Initialize the main weight buffer.
if should_create_grad_buffer_or_main_weight_buffer and preserve_fp32_weights:
group.main_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_main_weight_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=torch.float32,
device=self.device,
data_parallel_group=dp_group,
init_meta_only=True,
bucket_id=group_id,
)
# Initialize the main grad buffer.
if should_create_grad_buffer_or_main_weight_buffer:
group.main_grad_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_grad_buffer_distributed
and group.data_parallel_world_size > 1,
dtype=torch.float32 if grad_reduce_in_fp32 else grad_dtype,
device=self.device,
data_parallel_group=dp_group,
init_meta_only=True,
is_dtype_float8=not grad_reduce_in_fp32 and grad_dtype is torch.uint8,
gradient_scaling_factor=gradient_scaling_factor,
bucket_id=group_id,
)
if grad_reduce_in_fp32:
buffer_size[torch.float32] += group.main_grad_buffer.data_size
elif group.main_grad_buffer.is_dtype_float8:
buffer_size["float8"] += group.main_grad_buffer.data_size
else:
buffer_size[group.main_grad_buffer.dtype] += group.main_grad_buffer.data_size
reset_context_args = {"init_param_with_fp8": self.ddp_config.fp8_param_gather}
module_reset_flag = {}
if self.reset_parameters_for_meta_device_init_module:
self.param_to_direct_module = {}
for name, m in self.module.named_modules():
for p in m.parameters(recurse=False):
self.param_to_direct_module[p] = (name, m)
meta_params_numel = 0
cuda_params_numel = 0
cpu_params_numel = 0
for group in self.parameter_groups:
for p in group.params:
if p.is_meta:
meta_params_numel += p.numel()
elif p.device.type == 'cuda':
cuda_params_numel += p.numel()
else:
cpu_params_numel += p.numel()
log_str = (
f"Meta params numel: {meta_params_numel / 1_000_000:.2f} M, "
f"CUDA params numel: {cuda_params_numel / 1_000_000:.2f} M, "
f"CPU params numel: {cpu_params_numel / 1_000_000:.2f} M"
)
log_on_each_pipeline_stage(logger, logging.INFO, log_str)
# Initialize the model weight buffer data of each parameter group.
for group in self.parameter_groups:
wbuf = group.model_weight_buffer
if wbuf:
wbuf.data = torch.empty(wbuf.data_size, dtype=wbuf.dtype, device=self.device)
bucket = wbuf.fetch_bucket()
mbuf = group.main_weight_buffer
if mbuf:
mbuf.data = torch.empty(mbuf.data_size, dtype=mbuf.dtype, device=self.device)
for item_id, p in enumerate(group.params):
if wbuf:
if self.reset_parameters_for_meta_device_init_module and p.is_meta:
m_name, m = self.param_to_direct_module[p]
if not module_reset_flag.get(m_name, False) and hasattr(
m, "reset_parameters"
):
old_params = list(m.parameters(recurse=False))
# If the GPU memory over threshold, empty cache to leave
# some memory for initialization of the model on the
# CUDA device.
if check_gpu_memory(threshold=0.5):
gc.collect()
torch.cuda.empty_cache()
m.to_empty(device=self.device, recurse=False)
if is_te_min_version("0.9.0") and not isinstance(
m, TransformerEngineBaseModule
):
reset_context_args["with_cuda_rng_tracker"] = True
with ResetParametersContext(**reset_context_args):
m.reset_parameters()
module_reset_flag[m_name] = True
new_params = list(m.parameters(recurse=False))
self._reset_parameters(old_params, new_params)
p = group.params[item_id]
assert not p.is_meta, (self.param_to_name[p], module_reset_flag)
wbuf.set_item(item_id, p.data)
# reset the parameter data to the buffer
old_param_data = p.data
new_param_data = wbuf.get_item_from_bucket(bucket, item_id).view(p.shape)
if is_float8tensor(p):
p._data = new_param_data
else:
p.data = new_param_data
assert old_param_data._base is None
p.data.detach().copy_(old_param_data)
del old_param_data
if mbuf:
if hasattr(p, 'get_high_precision_init_val'):
mbuf.set_item(item_id, p.get_high_precision_init_val())
p.clear_high_precision_init_val()
else:
mbuf.set_item(item_id, p)
if wbuf and wbuf.is_data_distributed:
"""
When MCore Custom FSDP `optim_grads_params` is enabled,
it is necessary to save the tensor local shard. This local shard is
accessible through the `fully_shard_param_local_shard`
attribute of the tensor.
This attribute contains the local shard of the fully
sharded parameter, which is essential for correctly
saving and loading the model state when using
`optim_grads_params` with FSDP.
Example:
>>> # Assuming `tensor` is a fully sharded parameter
>>> local_shard = tensor.fully_shard_param_local_shard
>>> # Save the local shard as needed
"""
local_shard = wbuf.get_item(item_id, only_shard=True)
local_shard.fsdp_shard_orig_param = p
p.fully_shard_param_local_shard = local_shard
p.fully_shard_param_local_index = wbuf.locate_item_in_global_item(item_id)
def disable_shard_param_to_function(*unused):
"""Prevents users from accessing the 'to' operation
on parameters after sharding.
This restriction helps maintain data integrity and
proper sharding behavior by disabling direct 'to'
device/dtype operations on sharded parameters.
"""
raise RuntimeError(
"Your model is wrapped by MCore Custom FSDP. All "
"parameter dtypes and devices must be set before FSDP "
"wrapping. After FSDP wrapping, parameter storage "
"is sharded and you cannot modify parameter "
"dtypes or devices."
)
setattr(p, 'to', disable_shard_param_to_function)
def disable_shard_param_cpu_function(*unused):
warnings.warn(
"The parameters are sharded by custom fsdp, "
"and no actual cpu operation is performed."
)
return torch.empty([], device='cpu')
setattr(p, 'cpu', disable_shard_param_cpu_function)
if wbuf and wbuf.is_data_distributed:
wbuf.free_bucket_storage()
# Allocate the main_weight buffer and main_grad buffer data in one buffer.
if self.buffer_all_in_one:
self.buffer = {
torch.float32: torch.empty(
buffer_size[torch.float32], dtype=torch.float32, device=self.device
),
torch.float16: torch.empty(
buffer_size[torch.float16], dtype=torch.float16, device=self.device
),
torch.bfloat16: torch.empty(
buffer_size[torch.bfloat16], dtype=torch.bfloat16, device=self.device
),
"float8": torch.empty(buffer_size["float8"], dtype=torch.uint8, device=self.device),
}
offset = {torch.float32: 0, torch.float16: 0, torch.bfloat16: 0, "float8": 0}
def _alloc(dtype, size):
if self.buffer_all_in_one:
if dtype == torch.uint8:
dtype = "float8"
data = self.buffer[dtype][offset[dtype] : offset[dtype] + size]
offset[dtype] += size
return data
return torch.empty(size, dtype=dtype, device=self.device)
# Initialize the main grad buffer data of each parameter group.
for group in self.parameter_groups:
gbuf = group.main_grad_buffer
if not gbuf:
continue
gbuf.data = _alloc(gbuf.dtype, gbuf.data_size)
gbuf.data.zero_()
for item_id, p in enumerate(group.params):
p.fsdp_managed_main_grad = gbuf.get_item(item_id)
p._gbuf = gbuf
p._item_id = item_id
def main_grad_getter(p):
# Make sure main_grad memory storage ready.
bucket = p._gbuf.fetch_bucket()
gbuf = p._gbuf
item_id = p._item_id
if bucket.status == GradBucketStatus.GRAD_REDUCING:
if bucket.data_operation_event:
bucket.data_operation_event.wait()
bucket.data_operation_event = None
# The caller is assumed to be fetching main_grad for gradient
# accumulation here, so the bucket must not be freed until the
# subsequent gradient reduction has completed.
bucket.status = GradBucketStatus.GRAD_ACCUMULATING
return gbuf.get_item_from_bucket(bucket, item_id).view(p.shape)
setattr(p.__class__, 'main_grad', property(main_grad_getter))
if gbuf.is_data_distributed:
gbuf.free_bucket_storage()
gc.collect()
torch.cuda.empty_cache()
def _reset_parameters(self, old_params, new_params):
assert len(old_params) == len(new_params)
param_map = {}
for old_param, new_param in zip(old_params, new_params):
param_map[old_param] = new_param
self.param_to_name[new_param] = self.param_to_name[old_param]
del self.param_to_name[old_param]
self.param_to_param_group[new_param] = self.param_to_param_group[old_param]
del self.param_to_param_group[old_param]
self.param_to_direct_module[new_param] = self.param_to_direct_module[old_param]
del self.param_to_direct_module[old_param]
for item_id, p in enumerate(self.params):
if p in param_map:
new_p = param_map[p]
self.params[item_id] = new_p
for group in self.parameter_groups:
for item_id, p in enumerate(group.params):
if p not in param_map:
continue
new_p = param_map[p]
group.params[item_id] = new_p
for buf in [
group.model_weight_buffer,
group.main_weight_buffer,
group.main_grad_buffer,
]:
if buf is None:
continue
buf.param_idx[new_p] = buf.param_idx[p]
del buf.param_idx[p]
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
for group in self.parameter_groups:
if group.main_grad_buffer is None:
continue
group.main_grad_buffer.data *= scaling_factor
def zero_grad(self):
"""
Zero out the underlying grad_buffer and reset all buckets in preparation
for the next iteration of training.
"""
for _, param in self.optimizer_named_parameters:
if param.grad is not None and param.grad._base is None:
# For tensors that are not referenced, trying to use storage
# resize to make memory free immediately.
_free_storage(param.grad)
param.grad = None
for group in self.parameter_groups:
if group.main_grad_buffer is None:
continue
group.main_grad_buffer.data.zero_()
def _init_optimizer_named_parameters(self) -> List[Tuple[str, torch.nn.Parameter]]:
named_parameters = []
for pg in self.parameter_groups:
if pg.main_grad_buffer is None:
continue
optimizer_state_is_shard = pg.main_grad_buffer.is_data_distributed or (
pg.main_weight_buffer and pg.main_weight_buffer.is_data_distributed
)
for item_id, orig_param in enumerate(pg.params):
if pg.main_weight_buffer:
param = pg.main_weight_buffer.get_item(
item_id, only_shard=optimizer_state_is_shard
)
elif pg.model_weight_buffer:
param = pg.model_weight_buffer.get_item(
item_id, only_shard=optimizer_state_is_shard
)
else:
param = orig_param
def set_param_attribute_closure(param, orig_param):
def set_param_attribute():
for attr_name in [
'requires_grad',
'sequence_parallel',
'shared',
'tensor_model_parallel',
'partition_dim',
'partition_stride',
'is_embedding_or_output_parameter',
]:
if hasattr(orig_param, attr_name):
setattr(param, attr_name, getattr(orig_param, attr_name))
return set_param_attribute
setattr(param, 'reset_attribute', set_param_attribute_closure(param, orig_param))
setattr(param, 'orig_param', orig_param)
param.reset_attribute()
named_parameters.append((self.param_to_name[orig_param], param))
return named_parameters
def update_main_grads(self):
"""Update the main gradients for preparing the optimizer step."""
for _, param in self.optimizer_named_parameters:
param.reset_attribute()
orig_param = param.orig_param
group = self.parameter_groups[self.param_to_param_group[orig_param]]
item_id = group.main_grad_buffer.param_idx[orig_param]
optimizer_grad = group.main_grad_buffer.get_item(
item_id, only_shard=group.main_weight_buffer.is_data_distributed
)
setattr(
param,
'grad',
optimizer_grad.to(param.dtype) if optimizer_grad.numel() > 0 else None,
)
@property
def num_buckets(self):
"""Return the number of buckets."""
return len(self.parameter_groups)
@torch.no_grad()
def copy_main_weights_to_model_weights(self):
"""Update the model weights from the main weights."""
for pg in self.parameter_groups:
mbuf = pg.main_weight_buffer
wbuf = pg.model_weight_buffer
if mbuf is None:
continue
for param in pg.params:
item_id = mbuf.param_idx[param]
if wbuf:
if wbuf.is_data_distributed or mbuf.is_data_distributed:
model_param = wbuf.get_item(item_id, only_shard=True)
main_weight = mbuf.get_item(item_id, only_shard=True)
else:
model_param = wbuf.get_item(item_id)
main_weight = mbuf.get_item(item_id)
else:
assert not mbuf.is_data_distributed
model_param = param
main_weight = pg.main_weight_buffer.get_item(item_id)
if model_param.numel() == 0:
continue
if is_float8tensor(param):
# 1. When "--fp8-param-gather" is disabled, the main param
# is first casted to BF16/FP16, and then casted to FP8, so
# the amax_history is calculated using BF16/FP16 param.
# 2. When "--fp8-param-gather" is enabled, we can cast the
# FP32 main param to FP8 directly, which results in slightly
# different results with higher performance. In theory, this
# does not affect convergence.
# TODO: The following code maintains the logic of the point-1
# above. It can be deleted if it is not necessary.
main_weight = main_weight.to(param.dtype)
quantize_param_fragment(input_=main_weight, out=model_param, param=param)
else:
model_param.data.copy_(main_weight.view(model_param.shape))
@torch.no_grad()
def copy_model_weights_to_main_weights(self):
"""Copy the model weights to the main weights."""
for group in self.parameter_groups:
mbuf = group.main_weight_buffer
if mbuf is None:
continue
wbuf = group.model_weight_buffer
if mbuf.is_data_distributed:
copyin_data = wbuf.get_shard_from_local_buffer()
else:
copyin_data = wbuf.data
assert mbuf.data.numel() == copyin_data.numel(), (
f"Master weight buffer size {mbuf.data.numel()} does not match "
f"model weight buffer size {copyin_data.numel()}"
)
mbuf.data.copy_(copyin_data.data)
def all_gather_parameters(self, async_op: bool = True):
"""All gather the parameters.
Args:
async_op (bool, optional): Whether to perform the all-gather
asynchronously. Defaults to True.
"""
assert all(
[not g.model_weight_buffer.is_data_distributed for g in self.parameter_groups]
), 'all_gather_parameters() should only be called when parameters are not sharded.'
all_gather_ops = []
for g in self.parameter_groups:
shard = g.model_weight_buffer.get_shard_from_local_buffer()
all_gather_handler = torch.distributed.all_gather_into_tensor(
output_tensor=g.model_weight_buffer.data,
input_tensor=shard,
group=g.model_weight_buffer.data_parallel_group,
async_op=async_op,
)
if async_op:
all_gather_ops.append(all_gather_handler)
for op in all_gather_ops:
op.wait()
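# Call-pattern sketch (illustrative): valid only when model weights are kept
# replicated (e.g. the `optim` / `optim_grads` strategies). With async_op=True the
# collectives for all groups are issued first and then waited on together.
# >>> buffer.all_gather_parameters()                # overlap collectives, then wait
# >>> buffer.all_gather_parameters(async_op=False)  # issue and wait one group at a time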
def reduce_scatter_gradients(self, async_op: bool = True):
"""Reduce scatter the gradients.
Args:
async_op (bool, optional): Whether to perform the reduce-scatter
asynchronously. Defaults to True.
"""
assert all(
[not g.main_grad_buffer.is_data_distributed for g in self.parameter_groups if g.main_grad_buffer]
), 'reduce_scatter_gradients() should only be called when gradients are not sharded.'
reduce_scatter_ops = []
for g in self.parameter_groups:
gbuf = g.main_grad_buffer
if gbuf is None:
continue
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
reduce_scatter_handler = torch.distributed.reduce_scatter_tensor(
output=gbuf.get_shard_from_local_buffer(),
input=gbuf.data,
op=reduce_op,
group=g.main_grad_buffer.data_parallel_group,
async_op=async_op,
)
if async_op:
reduce_scatter_ops.append(reduce_scatter_handler)
for op in reduce_scatter_ops:
op.wait()
def all_reduce_gradients(self, async_op: bool = False):
"""All reduce the gradients.
Args:
async_op (bool, optional): Whether to do the all-reduce
asynchronously. Defaults to False.
"""
assert all(
[
not g.main_grad_buffer.is_data_distributed
for g in self.parameter_groups
if g.main_grad_buffer
]
), 'all_reduce_gradients() should only be called when gradients are not sharded.'
all_reduce_ops = []
for g in self.parameter_groups:
gbuf = g.main_grad_buffer
if gbuf is None:
continue
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, self.ddp_config)
all_reduce_handler = torch.distributed.all_reduce(
gbuf.data, op=reduce_op, group=gbuf.data_parallel_group, async_op=async_op
)
if async_op:
all_reduce_ops.append(all_reduce_handler)
for op in all_reduce_ops:
op.wait()
class BucketStatus(Enum):
"""
An enumeration of possible statuses for a data-parallel communication bucket.
Attributes:
EMPTY (int): The bucket is empty and not in use.
COMMUNICATING (int): The bucket is currently being used for communication.
READY_TO_USE (int): The bucket is filled with data and ready for use.
"""
EMPTY = 1
COMMUNICATING = 2
READY_TO_USE = 3
class GradBucketStatus(Enum):
"""
An enumeration of possible statuses for a gradient bucket.
Attributes:
GRAD_ACCUMULATING (int): The gradient bucket is currently accumulating gradients.
GRAD_REDUCING (int): The gradient bucket is currently reducing gradients.
"""
GRAD_ACCUMULATING = 1
GRAD_REDUCING = 2
class GradReducePipeline:
"""
Pipeline for reducing gradients.
"""
def __init__(
self,
param_and_grad_buffer: ParamAndGradBuffer,
cuda_stream: Optional[torch.cuda.Stream] = None,
check_nans: bool = False,
) -> None:
self.buffer = param_and_grad_buffer
self.grad_reduce_queue = []
self.bucket_status = {
i: BucketStatus.EMPTY
for i in range(self.buffer.num_buckets)
if self.buffer.parameter_groups[i].main_grad_buffer
}
self.buckets = {}
self.cuda_stream = cuda_stream
self.check_nans = check_nans
@property
def num_buckets(self):
"""Return the number of buckets."""
return self.buffer.num_buckets
def reset(self):
"""Reset the pipeline state."""
assert len(self.grad_reduce_queue) == 0, (
f"There are still pending reduce-scatter tasks, it is not safe to reset. "
f"items: {self.grad_reduce_queue.keys()}, bucket_status: {self.bucket_status}."
)
for bucket_id, _ in self.bucket_status.items():
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
gbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), (
f"There are still pending buckets, it is not safe to reset. "
f"bucket_status: {self.bucket_status}."
)
self.buckets = {}
def place_bucket(self, bucket_id: int) -> bool:
"""Place a full size bucket by bucket id.
Args:
bucket_id (int): The bucket id.
Returns:
bool: True if the bucket is placed successfully.
"""
assert bucket_id in self.bucket_status, f"Bucket {bucket_id} is not in the bucket status."
bucket_status = self.bucket_status[bucket_id]
if bucket_status == BucketStatus.READY_TO_USE:
return False
if bucket_status == BucketStatus.COMMUNICATING:
self.wait_for_previous_grad_reduce(0)
assert bucket_id not in self.buckets, f"Bucket {bucket_id} is already allocated."
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
bucket = gbuf.fetch_bucket()
requires_grad_items = sum([p.requires_grad for p in gbuf.params])
setattr(bucket, 'requires_grad_items', requires_grad_items)
setattr(bucket, 'items', [])
self.buckets[bucket_id] = bucket
self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE
return True
def wait_for_previous_grad_reduce(
self, recommeded_queue_size: int = 1, recommeded_queue_capacity: Optional[int] = None
):
"""
Wait for the previous reduce-scatter/all-reduce to finish.
Args:
recommeded_queue_size (int, optional): The recommended queue size. Defaults to 1.
recommeded_queue_capacity (Optional[int], optional): The recommended queue capacity.
Defaults to None.
"""
if recommeded_queue_capacity is not None:
queue_space = sum(
[
self.buffer.parameter_groups[bucket_id].main_grad_buffer.bucket_index.size
for _, _, bucket_id in self.grad_reduce_queue
]
)
while queue_space > recommeded_queue_capacity:
grad_reduce_event, free_up_grad_bucket, bucket_id = self.grad_reduce_queue.pop(0)
grad_reduce_event.wait()
free_up_grad_bucket()
queue_space -= self.buffer.parameter_groups[
bucket_id
].main_grad_buffer.bucket_index.size
else:
recommeded_queue_size = max(0, min(recommeded_queue_size, self.buffer.num_buckets - 1))
while len(self.grad_reduce_queue) > recommeded_queue_size:
grad_reduce_event, free_up_grad_bucket, _ = self.grad_reduce_queue.pop(0)
grad_reduce_event.wait()
free_up_grad_bucket()
def mark_item_ready(self, item: torch.Tensor, async_rs: bool = False) -> bool:
"""Mark the item ready for reduce-scatter/all-reduce.
Args:
item (torch.Tensor): The item to be marked.
async_rs (bool, optional): Whether to do the reduce-scatter/all-reduce
asynchronously. Defaults to False.
Returns:
bool: True if the bucket containing the item was dispatched for reduce-scatter/all-reduce.
"""
bucket_id = self.buffer.param_to_param_group[item]
assert bucket_id in self.buckets, f"Bucket {bucket_id} is not allocated."
scaling_factor = self.buffer.gradient_scaling_factor
bucket = self.buckets[bucket_id]
bucket.items.append(item)
assert len(bucket.items) <= bucket.requires_grad_items, "Too many items in the bucket."
if len(bucket.items) != bucket.requires_grad_items:
return False
self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING
current_stream = torch.cuda.current_stream()
reduce_scatter_stream = (
self.cuda_stream if self.cuda_stream is not None else torch.cuda.current_stream()
)
reduce_scatter_stream.wait_stream(current_stream)
with torch.cuda.stream(reduce_scatter_stream):
gbuf = self.buffer.parameter_groups[bucket_id].main_grad_buffer
scaling_factor = gbuf.gradient_scaling_factor
reduce_op = gradient_reduce_preprocessing(gbuf.data, scaling_factor, gbuf.ddp_config)
if gbuf.ddp_config.data_parallel_sharding_strategy == 'no_shard':
torch.distributed.all_reduce(
bucket.data, op=reduce_op, group=gbuf.data_parallel_group
)
else:
grad_shard = gbuf.get_shard_from_bucket(bucket)
grad_shard = torch.empty_like(grad_shard)
torch.distributed.reduce_scatter_tensor(
output=grad_shard,
input=bucket.data,
op=reduce_op,
group=gbuf.data_parallel_group,
)
if gbuf.is_data_distributed:
# Gradient accumulate on local buffer
local_buffer = gbuf.get_shard_from_local_buffer()
local_buffer += grad_shard
reduce_scatter_view_out_event = reduce_scatter_stream.record_event()
bucket.data_operation_event = reduce_scatter_view_out_event
bucket.status = GradBucketStatus.GRAD_REDUCING
del self.buckets[bucket_id]
def get_closure():
def free_up_grad_bucket():
nonlocal gbuf, local_buffer, bucket_id, bucket
if self.check_nans:
assert not torch.isnan(
local_buffer
).any(), f"NaN detected in bucket {bucket_id}: {local_buffer}"
# Special case: this bucket may already have been taken for gradient
# accumulation (via the `main_grad` property) before it had a chance to
# be freed here. In that case we skip the free, since a subsequent
# gradient reduction still has to run on this bucket.
if gbuf.is_data_distributed and bucket.status != GradBucketStatus.GRAD_ACCUMULATING:
gbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
return free_up_grad_bucket
free_up_grad_bucket = get_closure()
if async_rs:
self.grad_reduce_queue.append(
(reduce_scatter_view_out_event, free_up_grad_bucket, bucket_id)
)
return True
free_up_grad_bucket()
return True
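# Typical per-bucket flow (a sketch; in practice the FSDP backward hooks drive
# these calls, `pipeline` is a GradReducePipeline and `params` are the
# requires-grad parameters of that bucket):
# >>> pipeline.place_bucket(bucket_id)                # allocate the full-size grad bucket
# >>> for p in params:
# ...     pipeline.mark_item_ready(p, async_rs=True)  # the last item launches the collective
# >>> pipeline.wait_for_previous_grad_reduce(0)       # drain the async queue
# >>> pipeline.reset()                                # free shards between iterations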
class PrefetchOrder(Enum):
"""
An enumeration of possible prefetch orders for data-parallel operations.
Attributes:
FORWARD_PASS_ORDER (int): Prefetch in the order of forward pass computation.
BACKWARD_PASS_ORDER (int): Prefetch in the order of backward pass computation.
"""
FORWARD_PASS_ORDER = 0
BACKWARD_PASS_ORDER = 1
class AllGatherPipeline:
"""
Pipeline for all-gathering parameters.
"""
def __init__(self, param_and_grad_buffer: ParamAndGradBuffer) -> None:
self.buffer = param_and_grad_buffer
self.param_gather_event_map = {}
self.bucket_status = {i: BucketStatus.EMPTY for i in range(self.buffer.num_buckets)}
self.bucket_can_be_released = {i: False for i in range(self.buffer.num_buckets)}
@property
def num_buckets(self):
"""Return the number of buckets."""
return self.buffer.num_buckets
def reset(self):
"""Reset the pipeline state."""
if len(self.param_gather_event_map) > 0:
warnings.warn(
"There are still pending all-gather tasks, process them."
f"Bucket status: {self.bucket_status}.",
UserWarning,
)
while len(self.param_gather_event_map) > 0:
bucket_id = next(iter(self.param_gather_event_map))
self.wait_bucket_ready(bucket_id)
for bucket_id in self.bucket_can_be_released:
self.bucket_can_be_released[bucket_id] = True
self.recycle_unused_buckets()
assert all([status is BucketStatus.EMPTY for status in self.bucket_status.values()]), (
f"There are still working buckets, it is not safe to reset. "
f"bucket_status: {self.bucket_status}."
)
assert all(
[not can_be_released for can_be_released in self.bucket_can_be_released.values()]
), (
f"The bucket can be released table is in an abnormal state, not safe to reset. "
f"bucket_can_be_released: {self.bucket_can_be_released}."
)
def queue_bucket_to_all_gather(
self,
bucket_id: int,
prefetch: bool = False,
prefetch_order: PrefetchOrder = PrefetchOrder.FORWARD_PASS_ORDER,
suggested_AG_prefetch_size: Optional[int] = None,
):
"""Performs an asynchronous all-gather operation by queuing the task bucket into
a dedicated queue (NCCL CUDA Stream).
This function is a part of FSDP (Fully Sharded Data Parallel)
implementation that handles the all-gather operation in a queue-based
manner. Instead of executing the all-gather immediately, it enqueues
the operation into a task queue, which helps manage system resources and
prevents overwhelming the GPU memory and communication bandwidth.
The queued all-gather operation will:
* Collect distributed sharded parameters from all participating processes
* Reconstruct the full parameter tensor
Args:
bucket_id (int): The bucket ID to be queued for all-gathering.
prefetch (bool, optional): Whether to prefetch the next bucket. Defaults to False.
prefetch_order (PrefetchOrder, optional): The order of prefetching.
Defaults to PrefetchOrder.FORWARD_PASS_ORDER.
suggested_AG_prefetch_size (Optional[int], optional):
The suggested prefetch size for all-gathering. Defaults to None.
"""
parameter_groups = self.buffer.parameter_groups
ag_buckets = [bucket_id]
# If prefetch is enabled, we will add prefetch buckets to ag_buckets.
if prefetch:
if suggested_AG_prefetch_size is not None:
all_gather_size = parameter_groups[bucket_id].model_weight_buffer.bucket_index.size
while all_gather_size < suggested_AG_prefetch_size:
if prefetch_order == PrefetchOrder.FORWARD_PASS_ORDER:
next_bucket_id = bucket_id + 1
else:
next_bucket_id = bucket_id - 1
if next_bucket_id < 0 or next_bucket_id >= self.buffer.num_buckets:
break
next_group = parameter_groups[next_bucket_id]
ag_buckets.append(next_bucket_id)
all_gather_size += next_group.model_weight_buffer.bucket_index.size
bucket_id = next_bucket_id
else:
if prefetch_order == PrefetchOrder.FORWARD_PASS_ORDER:
next_bucket_id = bucket_id + 1
else:
next_bucket_id = bucket_id - 1
if next_bucket_id >= 0 and next_bucket_id < self.buffer.num_buckets:
ag_buckets.append(next_bucket_id)
# Launch all-gather operations for all buckets in ag_buckets.
for bucket_id in ag_buckets:
self.all_gather_bucket_and_set_items(bucket_id, async_op=True)
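# Unshard-path sketch (illustrative; `ag_pipeline` is an AllGatherPipeline, the
# bucket id comes from buffer.param_to_param_group[param], and the prefetch size
# is a made-up value):
# >>> ag_pipeline.queue_bucket_to_all_gather(
# ...     bucket_id, prefetch=True, suggested_AG_prefetch_size=64 * 1024 * 1024)
# >>> ag_pipeline.wait_bucket_ready(bucket_id)   # block until the weights are gathered
# Buckets marked in `bucket_can_be_released` are then freed lazily by
# recycle_unused_buckets() the next time a bucket is gathered.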
def wait_bucket_ready(self, bucket_id, empty_ok=False):
"""Wait for the bucket to be ready."""
if self.bucket_status[bucket_id] == BucketStatus.READY_TO_USE:
return
if self.bucket_status[bucket_id] == BucketStatus.EMPTY:
if empty_ok:
return
raise ValueError(f"Bucket {bucket_id} is empty.")
param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id)
param_gather_event.wait()
mark_bucket_ready_to_use()
@torch.no_grad()
def release_bucket(self, bucket_id: int):
"""Release the bucket."""
if self.bucket_status[bucket_id] == BucketStatus.EMPTY:
return
if self.bucket_status[bucket_id] == BucketStatus.COMMUNICATING:
raise ValueError(f"Bucket {bucket_id} is communicating.")
wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer
wbuf.free_bucket_storage()
self.bucket_status[bucket_id] = BucketStatus.EMPTY
def recycle_unused_buckets(self):
"""Recycle the unused buckets."""
for bucket_id, can_be_released in self.bucket_can_be_released.items():
if can_be_released:
self.release_bucket(bucket_id)
self.bucket_can_be_released[bucket_id] = False
@torch.no_grad()
def all_gather_bucket_and_set_items(self, bucket_id: int, async_op: bool = False) -> None:
"""All-gather the bucket and set the items."""
self.bucket_can_be_released[bucket_id] = False
if self.bucket_status[bucket_id] != BucketStatus.EMPTY:
return
self.bucket_status[bucket_id] = BucketStatus.COMMUNICATING
wbuf = self.buffer.parameter_groups[bucket_id].model_weight_buffer
# Lazy release the unused buckets.
self.recycle_unused_buckets()
bucket = wbuf.fetch_bucket(and_allocate_params_data=True)
param_gather_event = torch.distributed.all_gather_into_tensor(
output_tensor=bucket.data,
input_tensor=wbuf.get_shard_from_local_buffer(),
group=wbuf.data_parallel_group,
async_op=async_op,
)
def get_closure():
@torch.no_grad()
def mark_bucket_ready_to_use():
nonlocal wbuf, bucket_id
self.bucket_status[bucket_id] = BucketStatus.READY_TO_USE
return mark_bucket_ready_to_use
mark_bucket_ready_to_use = get_closure()
if async_op:
self.param_gather_event_map[bucket_id] = (param_gather_event, mark_bucket_ready_to_use)
return
mark_bucket_ready_to_use()
@torch.no_grad()
def gradient_reduce_preprocessing(grad_data, scaling_factor, ddp_config):
"""
Gradient reduce preprocessing for gradient averaging and gradient scaling.
"""
if scaling_factor is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
elif ddp_config.gradient_reduce_div_fusion and grad_data.dtype != torch.bfloat16:
reduce_op = torch.distributed._make_nccl_premul_sum(scaling_factor)
else:
grad_data.mul_(scaling_factor)
reduce_op = torch.distributed.ReduceOp.SUM
return reduce_op
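# Illustrative example of the branch selection above (`ddp_config` is faked with a
# SimpleNamespace carrying just the two flags the function reads):
# >>> from types import SimpleNamespace
# >>> cfg = SimpleNamespace(average_in_collective=False, gradient_reduce_div_fusion=False)
# >>> grads = torch.ones(4, dtype=torch.bfloat16)
# >>> op = gradient_reduce_preprocessing(grads, scaling_factor=0.25, ddp_config=cfg)
# >>> # Div fusion is off, so grads are pre-scaled in place and plain SUM is returned.
# >>> assert op == torch.distributed.ReduceOp.SUM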
def check_gpu_memory(threshold=0.9):
"""
Check if the GPU memory is over the threshold.
Args:
threshold (float, optional): The threshold to check if the GPU memory is over.
Defaults to 0.9.
Returns:
bool: True if the GPU memory is over the threshold.
"""
if not torch.cuda.is_available():
return False
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device)
reserved = torch.cuda.memory_reserved(device)
total = torch.cuda.get_device_properties(device).total_memory
allocated_ratio = allocated / total
reserved_ratio = reserved / total
near_full = allocated_ratio >= threshold or reserved_ratio >= threshold
if near_full:
log_on_each_pipeline_stage(
logger,
logging.INFO,
f"GPU Memory: Allocated: {allocated_ratio:.2%}, Reserved: {reserved_ratio:.2%}",
)
return near_full
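# Usage sketch: callers (see the meta-device init path above) use this as a cheap
# guard before large allocations, purging the CUDA cache when the device is nearly full.
# >>> if check_gpu_memory(threshold=0.8):
# ...     torch.cuda.empty_cache()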
class ResetParametersContext:
"""
Context manager for resetting parameters for meta device initialization module.
"""
def __init__(self, init_param_with_fp8=False, with_cuda_rng_tracker=False):
self.init_param_with_fp8 = init_param_with_fp8
self.with_cuda_rng_tracker = with_cuda_rng_tracker
def __enter__(self):
self.stack = ExitStack()
if self.init_param_with_fp8:
args = {"enabled": True}
if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
args["preserve_high_precision_init_val"] = True
self.stack.enter_context(fp8_model_init(**args))
if self.with_cuda_rng_tracker:
self.stack.enter_context(get_cuda_rng_tracker().fork())
return self
def __exit__(self, *exc_details):
self.stack.__exit__(*exc_details)
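# Usage sketch (mirrors the meta-device init path above; `te_module` is a
# hypothetical TransformerEngine-based submodule):
# >>> te_module.to_empty(device='cuda', recurse=False)
# >>> with ResetParametersContext(init_param_with_fp8=True, with_cuda_rng_tracker=True):
# ...     te_module.reset_parameters()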
......@@ -70,7 +70,7 @@ class _BaseDataParallel(MegatronModule):
"""
pass
def state_dict(self, prefix='', keep_vars=False):
def state_dict(self, prefix='', keep_vars=False, destination=None):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
......@@ -79,7 +79,7 @@ class _BaseDataParallel(MegatronModule):
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, destination=destination)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
......