Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed
File mode changed from 100755 to 100644
@@ -28,9 +28,10 @@ async_calls = AsyncCallsQueue()

 def get_default_strategy(action: StrategyAction, backend: str, version: int):
     """Retrieves a default strategy for a given action, backend and version."""
+    error_hint: str = None
     try:
         if backend == 'zarr':
-            error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
+            error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages'
             from .tensorstore import register_default_tensorstore_strategies

             register_default_tensorstore_strategies()
...
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" FS Reader with metadata cached support. """

import os
from typing import Union

from torch.distributed.checkpoint import FileSystemReader, Metadata


class CachedMetadataFileSystemReader(FileSystemReader):
    """
    Extends FileSystemReader to cache metadata for improved performance.

    Attributes:
        _cached_metadata (Metadata or None): Cached metadata from the file system.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize with file system path.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
        """
        super().__init__(path=path)
        self._cached_metadata = None

    def read_metadata(self) -> Metadata:
        """
        Read metadata from file system, caching for subsequent calls.

        Returns:
            Metadata: Checkpoint metadata.
        """
        if self._cached_metadata is None:
            self._cached_metadata = super().read_metadata()
        return self._cached_metadata
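A brief usage sketch of the reader above (the checkpoint path is hypothetical): the first read_metadata() call hits storage, every later call returns the cached Metadata object, which matters when several load phases re-read the same checkpoint metadata.

    # Minimal sketch, assuming a PyT Distributed checkpoint exists at this (made-up) path.
    reader = CachedMetadataFileSystemReader("/checkpoints/iter_0000100")
    md_first = reader.read_metadata()   # reads the .metadata file from storage
    md_again = reader.read_metadata()   # served from the in-memory cache
    assert md_first is md_again         # same object, no second filesystem read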
@@ -69,7 +69,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
         """
         load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
         try:
-            return torch.load(load_path, map_location='cpu')
+            return torch.load(load_path, map_location='cpu', weights_only=False)
         except FileNotFoundError as e:
             err_msg = f'Common file {load_path} does not exist'
             ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
@@ -95,12 +95,12 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
         sh_obj.data = None
         load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
         try:
-            loaded_obj = torch.load(load_path)
+            loaded_obj = torch.load(load_path, weights_only=False)
         except FileNotFoundError as e:
             # Backward compatible logic: previously the save format was incorrect
             old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
             try:
-                loaded_obj = torch.load(old_load_path)
+                loaded_obj = torch.load(old_load_path, weights_only=False)
             except FileNotFoundError:
                 err_msg = f'Object shard {load_path} not found'
                 obj_subdir = checkpoint_dir / sh_obj.key
...
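The added `weights_only=False` arguments keep `torch.load` permissive on newer PyTorch releases, where the default flipped to `weights_only=True` and arbitrary pickled Python objects stored alongside tensors could otherwise fail to load. A minimal, hedged sketch of the behaviour (path and payload are made up):

    import torch

    # A common-state payload typically mixes tensors with plain Python objects.
    payload = {"iteration": 100, "np_rng_state": ("MT19937", [1, 2, 3])}
    torch.save(payload, "/tmp/common.pt")

    # Restricted mode may reject non-tensor content depending on the PyTorch version;
    # passing weights_only=False restores the old, permissive behaviour.
    state = torch.load("/tmp/common.pt", map_location="cpu", weights_only=False)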
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 """ Storage writer for PyT Distributed format allowing asynchronous save. """

-import gc
+import dataclasses
 import logging
 import os
 import queue
-from contextlib import contextmanager
+from functools import partial
+from heapq import heappop, heappush
 from itertools import chain
+from operator import itemgetter
 from pathlib import Path
 from time import time
 from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -20,6 +22,8 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte
 from torch.distributed.checkpoint.storage import WriteResult
 from torch.futures import Future

+from .async_utils import _disable_gc
+
 logger = logging.getLogger(__name__)

 WriteBucket = Tuple[Path, str, Tuple[list, list]]  # represents writes to a single file
@@ -35,19 +39,6 @@ def _get_write_results_queue():
     return _results_queue


-@contextmanager
-def _disable_gc():
-    """Temporarily disables GC."""
-    gc_enabled = gc.isenabled()
-    try:
-        if gc_enabled:
-            gc.disable()
-        yield
-    finally:
-        if gc_enabled:
-            gc.enable()
-
-
 class FileSystemWriterAsync(FileSystemWriter):
     """
     Async-enabled implementation of FileSystemWriter using file IO.
@@ -76,6 +67,8 @@ class FileSystemWriterAsync(FileSystemWriter):
                 'single_file_per_rank flag not supported for FileSystemWriterAsync'
             )

+        self.can_run_decentralized_global_plan: bool = True
+
         # Intermediate state between preparation and finalization
         self.write_buckets: Optional[List[WriteBucket]] = None
         self.results_queue: Optional[mp.Queue] = None
@@ -99,7 +92,7 @@ class FileSystemWriterAsync(FileSystemWriter):
                 self.thread_count > 1
             ), "thread_count must be at least 2 if separation_hint is provided"
         bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
-        item_buckets = _split_by_size_and_type(bins, plan.items, self.separation_hint)
+        item_buckets = _split_by_size_and_type(bins, plan.items)
         logger.debug(f"bucket_prep, time: {time() - start}")

         start = time()
@@ -113,6 +106,23 @@ class FileSystemWriterAsync(FileSystemWriter):
             file_count += 1
             return file_name

+        def _clone_if_needed(ten: torch.Tensor):
+            """Clone if we detect incontiguous storage for CPU tensors
+
+            Makes sure we perform a `clone` only if we detect incontiguous storage,
+            so that we don't blow up host memory unnecessarily.
+
+            TODO: For persistent worker, this work should be changed to move the cpu tensor
+            to shared_memory.
+            """
+            ten = ten.detach()
+            if ten.device.type != "cpu":
+                # We do D2H later when the async_request is scheduled for both sync / async
+                # checkpointing
+                return ten
+            is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
+            return ten.clone() if is_view else ten
+
         # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
         self.write_buckets = []
         for group_name, group_buckets in _split_by_separation_hint(
@@ -125,7 +135,7 @@ class FileSystemWriterAsync(FileSystemWriter):
                     if item.type == WriteItemType.BYTE_IO
                 ]
                 tensor_data = [
-                    (item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
+                    (item, _clone_if_needed(planner.resolve_data(item)))
                     for item in bucket
                     if item.type != WriteItemType.BYTE_IO
                 ]
@@ -147,23 +157,49 @@ class FileSystemWriterAsync(FileSystemWriter):
         end = time()
         logger.debug(f"D2H and push, time: {end - start}")

-    def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
+    def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
         """
         Get function that saves the data to storage along with its arguments.
         Allows the external caller to apply the save function synchronously or asynchronously.

         Returns: None (if there is nothing to write on this rank) or a tuple of:
-            - the function that saves the data
-            - arguments to that function
+            1) the function that saves the data.
+            2) the function that stages the GPU tensors to a destination for async checkpointing.
+               This function should be self-contained.
+            3) arguments to that function in 1).
         """
         if not self.write_buckets:
-            return None, ()
-        return (self.write_preloaded_data_multiproc, (self.write_buckets, self.results_queue))
+            return None, None, ()
+        return (
+            self.write_preloaded_data_multiproc,
+            partial(self.preload_tensors, self.write_buckets, True),
+            [torch.distributed.get_rank(), self.write_buckets, self.results_queue],
+        )
+
+    @staticmethod
+    def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
+        """Preload tensors in state_dict to host memory through CPU memory
+
+        Args:
+            write_buckets(List): List of `WriteBucket`,
+                which includes what to be saved in a checkpoint
+            non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
+        """
+        result = []
+        for bucket in write_buckets:
+            file_name, storage_key, (bytes_data, tensor_data) = bucket
+            tensor_data = [
+                (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
+            ]
+            result.append((file_name, storage_key, (bytes_data, tensor_data)))
+        if non_blocking:
+            torch.cuda.synchronize()
+        return result

     @staticmethod
     @_disable_gc()
     def write_preloaded_data_multiproc(
-        write_buckets: List[WriteBucket], global_results_queue: mp.Queue
+        rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue
     ) -> None:
         """
         Performs saving data to storage with multiple processes.
@@ -186,6 +222,7 @@ class FileSystemWriterAsync(FileSystemWriter):
             (or an Exception) from parallel write processes to the main training process
         Returns: None
         """
+        logger = logging.getLogger(__name__)
         w_start = time()
         write_results_or_exc: Union[dict, Exception] = dict()
         ctx = mp.get_context('fork')
@@ -234,20 +271,16 @@
                     logger.error(err_msg)
                     write_results_or_exc = local_results_or_exc
                     break
-                else:
-                    assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
-                    write_results_or_exc[local_proc_idx] = local_results_or_exc
-                    p_list[local_proc_idx].join()
+                assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
+                write_results_or_exc[local_proc_idx] = local_results_or_exc
+                p_list[local_proc_idx].join()

         logger.debug('FileSystemWriterAsync: collected worker results successfully')

         global_results_queue.put(write_results_or_exc)

         w_end = time()
-        logger.debug(
-            f"{w_end}, rank: {torch.distributed.get_rank()},"
-            f" write(sync,parallel): {w_end - w_start}"
-        )
+        logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}")

     @staticmethod
     @_disable_gc()
@@ -271,6 +304,8 @@
         Returns: None, the write result are put into the `queue`
         """
+        logger = logging.getLogger(__name__)
+        logger.debug(f'{local_proc_idx} started')
         mem_before = _process_memory()

         local_results = []
@@ -288,6 +323,7 @@
                         os.fsync(stream.fileno())
             local_output = (local_proc_idx, local_results)
         except Exception as e:
+            logger.debug(f'{local_proc_idx} failed')
             local_output = (local_proc_idx, e)
         results_queue.put(local_output)
@@ -334,10 +370,23 @@
         )
         return list(chain.from_iterable(write_results.values()))

+    def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
+        """Instead of assigning indices by plan order, uses PyT rank (same outcome).
+
+        Args:
+            local_plan (SavePlan): local plan to turn to a global plan
+                (without interactions with other ranks)
+
+        Returns:
+            SavePlan - locally transformed plan equivalent to the plan that would be
+                created by the coordinator
+        """
+        return dataclasses.replace(
+            local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
+        )
+

-def _split_by_size_and_type(
-    bins: int, items: List[WriteItem], separation_hint: Optional[str] = None
-) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
     """
     Splits write items according to item size into close to uniform bins.
@@ -353,24 +402,32 @@ def _split_by_size_and_type(
     if bins == 1:
         return [items]

-    bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
-    tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
+    bytes_items: List[WriteItem] = []
+    tensor_items: List[WriteItem] = []
+    for wi in items:
+        container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
+        container.append(wi)

     buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
     bucket_sizes = [0 for _ in range(bins)]

-    tensor_items.sort(key=_item_size, reverse=True)
-
     # Assign bytes with a simple round-robin
     for i, item in enumerate(bytes_items):
         buckets[i % bins].append(item)

-    # Then, assign tensors according to their sizes
-    for item in tensor_items:
-        # TODO replace with headq
-        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
-        buckets[idx].append(item)
-        bucket_sizes[idx] += _item_size(item)
+    # Sort tensor items by size in decreasing order once and store the size with item
+    sized_tensors = [(item, _item_size(item)) for item in tensor_items]
+    sized_tensors.sort(key=itemgetter(1), reverse=True)
+
+    # Use a min heap for bin assignment
+    # Store (total_size_of_bin, bin_index) tuples
+    heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]
+
+    # Assign tensors using heap
+    for item, size in sized_tensors:
+        total_bin_size, bin_idx = heappop(heap)
+        buckets[bin_idx].append(item)
+        heappush(heap, (total_bin_size + size, bin_idx))

     return buckets
...
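The rewritten `_split_by_size_and_type` is a greedy longest-processing-time-first packer: tensors are sorted by size descending and each one goes into the currently lightest bin, tracked with a min-heap instead of a linear `min()` scan per item. A standalone sketch of the same idea (the `pack` helper and the sizes are illustrative only):

    from heapq import heappop, heappush

    def pack(sizes, bins):
        # Greedy LPT: largest item first, always into the currently lightest bin.
        buckets = [[] for _ in range(bins)]
        heap = [(0, i) for i in range(bins)]  # (total_size, bin_index)
        for size in sorted(sizes, reverse=True):
            total, idx = heappop(heap)
            buckets[idx].append(size)
            heappush(heap, (total + size, idx))
        return buckets

    print(pack([7, 5, 4, 3, 2, 1], bins=2))  # [[7, 3, 1], [5, 4, 2]] -> totals 11 and 11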
@@ -2,12 +2,13 @@
 import logging
 from pathlib import Path
 from time import time
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, TypeVar

 import torch
 import torch.distributed as dist
+from torch.distributed.checkpoint import Metadata

-from megatron.core.dist_checkpointing import ShardedTensor
+from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor
 from megatron.core.dist_checkpointing.core import CheckpointingException
 from megatron.core.dist_checkpointing.dict_utils import (
     dict_list_map_inplace,
@@ -19,6 +20,7 @@ from megatron.core.dist_checkpointing.exchange_utils import (
     ShardDistribution,
     determine_main_replica_uniform_distribution,
     exchange_by_distribution,
+    exchange_loaded_objects_gather_object,
 )
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
 from megatron.core.dist_checkpointing.strategies.base import (
@@ -26,7 +28,12 @@ from megatron.core.dist_checkpointing.strategies.base import (
     LoadShardedStrategy,
     SaveShardedStrategy,
 )
-from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
+from megatron.core.dist_checkpointing.utils import (
+    _sharded_object_id,
+    _sharded_tensor_shard_id,
+    _ShardId,
+    debug_time,
+)
 from megatron.core.dist_checkpointing.validation import (
     determine_global_metadata,
     validate_sharding_integrity,
@@ -34,6 +41,8 @@ from megatron.core.dist_checkpointing.validation import (
 logger = logging.getLogger(__name__)

+T = TypeVar('T', ShardedObject, ShardedTensor)
+

 class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
     """Wraps arbitrary strategy and distributes the save during `save`.
@@ -170,7 +179,9 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
         self.exchange_algo = exchange_algo
         self.cached_distribution: Optional[ShardDistribution] = None
+        self.cached_global_metadata: Optional[Metadata] = None

+    @debug_time("FullyParallelLoadStrategyWrapper.load", logger)
     def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
         """Distributes the load and calls underlying strategy only for parts of the state dict.
@@ -200,18 +211,20 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
             a state dict that would be loaded with the underlying strategy
             without this wrapper.
         """
+        loaded_state_dict = {}
         if torch.distributed.get_world_size(self.parallelization_group) <= 1:
             return self.base_strategy.load(sharded_state_dict, checkpoint_dir)

         # Step 1 and 2: exchange load metadata and distribute the load
-        start = time()
-        precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
-        assert (
-            precomputed_distribution is not None
-        ), 'Expecting non-trivial distribution for non-trivial parallelization group'
-        end = time()
-        logger.debug(f'self.apply_loading_parallelization took {end - start}s')
-        start = end
+        with debug_time("self.apply_loading_parallelization", logger):
+            precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
+                sharded_state_dict
+            )
+            assert (
+                precomputed_distribution is not None
+            ), 'Expecting non-trivial distribution for non-trivial parallelization group'

         # Step 3: load part of the checkpoint.
         # Load only sharded objects first. ShardedTensors will be loaded separately
@@ -219,88 +232,121 @@
         (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
             self._defer_loading_sharded_tensors(sharded_state_dict)
         )
-        loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
-        end = time()
-        logger.debug(f'Base load of ShardedObjects took {end - start}s')
-        start = end
-
-        # Load sharded tensors separately
-        loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
-
-        end = time()
-        logger.debug(f'Base load of ShardedTensors took {end - start}s')
-        start = end
-
-        # Step 4: exchange data between ranks
-        logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
-        all_loaded_tensors = exchange_by_distribution(
-            loaded_tensors,
-            unloaded_shards,
-            precomputed_distribution,
-            self.parallelization_group,
-            self.exchange_algo,
-        )
-        if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
-            missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
-            raise CheckpointingException(
-                f'Missing shards after fully parallel loading: {missing_shards}'
-            )
-
-        sync_start = time()
-        torch.cuda.synchronize()
-        end = time()
-        logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
-        logger.debug(f'self.exchange_loaded_tensors took {end - start}s')
+        (sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = (
+            self._defer_loading_sharded_objects(sharded_state_dict)
+        )
+        assert (
+            len(sharded_state_dict) == 0
+        ), "sharded_state_dict is not empty after deferring tensors and objects"
+
+        with debug_time("base_load_ShardedObjects", logger):
+            # Load sharded objects first
+            loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)
+
+        with debug_time("base_load_ShardedTensors", logger):
+            # Load sharded tensors separately
+            loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
+
+        with debug_time("self.exchange_loaded_tensors", logger):
+            # Step 4: exchange data between ranks
+            logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
+            all_loaded_tensors = exchange_by_distribution(
+                loaded_tensors,
+                unloaded_shards,
+                precomputed_distribution,
+                self.parallelization_group,
+                self.exchange_algo,
+            )
+            if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
+                missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
+                raise CheckpointingException(
+                    f'Missing shards after fully parallel loading: {missing_shards}'
+                )
+
+            with debug_time("torch.cuda.synchronize", logger):
+                torch.cuda.synchronize()
+
+        all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects)
+
+        if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()):
+            missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys()
+            raise CheckpointingException(
+                f'Missing object shards after fully parallel loading: {missing_object_shards}'
+            )
+        torch.cuda.synchronize()

         self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
+        self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects)
+
+        merge(loaded_state_dict, sharded_objects)
         merge(loaded_state_dict, sharded_tensors)
+
+        if hasattr(self.base_strategy, "cached_global_metadata"):
+            self.cached_global_metadata = self.base_strategy.cached_global_metadata
+
         return loaded_state_dict
+    @staticmethod
+    def _defer_loading_sharded_objects(
+        sharded_state_dict: ShardedStateDict,
+    ) -> Tuple[
+        ShardedStateDict,
+        ShardedStateDict,
+        Dict[_ShardId, ShardedObject],
+        Dict[_ShardId, ShardedObject],
+    ]:
+        return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id)
+
+    @staticmethod
     def _defer_loading_sharded_tensors(
-        self, sharded_state_dict: ShardedStateDict
+        sharded_state_dict: ShardedStateDict,
     ) -> Tuple[
         ShardedStateDict,
         ShardedStateDict,
         Dict[_ShardId, ShardedTensor],
         Dict[_ShardId, ShardedTensor],
     ]:
-        """Divides state dict into parts loaded by this vs other ranks.
-
-        ShardedTensors with main replica_id will be loaded by this rank,
-        others will be received by other ranks (after loading from storage).
-
-        Args:
-            sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
-                that will be divided.
-
-        Returns: a tuple of:
-            - ShardedStateDict: sub-state dict only with ShardedTensors
-            - ShardedStateDict: sub-state dict with non-ShardedTensors
-            - Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified
-                by shard ids. This is a mapping from shard id to a corresponding
-                ShardedTensor for tensors loaded by *this* rank
-            - Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
-                ShardedTensor for tensors loaded by *other* ranks
-        """
-        to_load_shards = {}
-        unloaded_shards = {}
-
-        sharded_tensors, sharded_state_dict = extract_matching_values(
-            sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
-        )
-
-        def wrap_non_main_replicas(x):
-            if isinstance(x, ShardedTensor):
-                # Assign shard to be loaded or not
-                if is_main_replica(x.replica_id):
-                    to_load_shards[_sharded_tensor_shard_id(x)] = x
-                else:
-                    unloaded_shards[_sharded_tensor_shard_id(x)] = x
-            return x
-
-        dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
-        return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards
+        return _defer_loading_sharded_items(
+            sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id
+        )
+
+    @staticmethod
+    def fill_in_deferred_sharded_objects(
+        sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any]
+    ) -> None:
+        """Fill in objects not loaded by current rank with objects from `loaded_objects` map.
+
+        Args:
+            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
+                ShardedObjects are completely replaced with corresponding objects.
+            loaded_objects (Dict[_ShardId, Any]): dict allowing to map
+                ShardedObject from the sharded_state_dict to loaded objects.
+
+        Returns:
+            None
+        """
+        _fill_in_deferred_sharded_items(
+            sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id
+        )
+
+    @staticmethod
+    def fill_in_deferred_sharded_tensors(
+        sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
+    ) -> None:
+        """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
+
+        Args:
+            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
+                ShardedTensors are completely replaced with corresponding torch.Tensors.
+            loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
+                ShardedTensor from the sharded_state_dict to loaded tensors.
+
+        Returns:
+            None
+        """
+        _fill_in_deferred_sharded_items(
+            sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id
+        )
     def apply_loading_parallelization(
         self, sharded_state_dict: ShardedStateDict
@@ -339,34 +385,6 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
         return precomputed_distribution

-    def fill_in_deferred_sharded_tensors(
-        self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
-    ) -> None:
-        """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
-
-        Args:
-            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
-                ShardedTensors are completely replaced with corresponding torch.Tensors.
-            loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
-                ShardedTensor from the sharded_state_dict to loaded tensors.
-
-        Returns:
-        """
-
-        def fill_in_sharded_tensor(x):
-            if isinstance(x, ShardedTensor):
-                try:
-                    x = loaded_tensors[_sharded_tensor_shard_id(x)]
-                except KeyError as e:
-                    raise CheckpointingException(
-                        f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
-                    ) from e
-            return x
-
-        dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)
-
     @property
     def can_handle_sharded_objects(self):
         return self.base_strategy.can_handle_sharded_objects
@@ -437,3 +455,61 @@ def distribute_main_replicas_with_precomputed_distribution(
             sh_ten.replica_id = 0
         else:
             sh_ten.replica_id = 1
+
+
+def _defer_loading_sharded_items(
+    sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId]
+) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]:
+    """Divides state dict into parts loaded by this vs other ranks.
+
+    Args:
+        sharded_state_dict (ShardedStateDict): state dict with sharded items
+            that will be divided.
+        item_type: The type of sharded item (ShardedObject or ShardedTensor)
+        shard_id_func: Function to get the shard ID for the item type
+
+    Returns: a tuple of:
+        - ShardedStateDict: sub-state dict only with sharded items
+        - ShardedStateDict: sub-state dict with non-sharded items
+        - Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank
+        - Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks
+    """
+    to_load_shards = {}
+    unloaded_shards = {}
+
+    sharded_items, remaining_state_dict = extract_matching_values(
+        sharded_state_dict, lambda v: isinstance(v, item_type)
+    )
+
+    def wrap_non_main_replicas(x: Any) -> Any:
+        if isinstance(x, item_type):
+            shard_id = shard_id_func(x)
+            if is_main_replica(x.replica_id):
+                to_load_shards[shard_id] = x
+            else:
+                unloaded_shards[shard_id] = x
+        return x
+
+    dict_list_map_inplace(wrap_non_main_replicas, sharded_items)
+    return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards
+
+
+def _fill_in_deferred_sharded_items(
+    sharded_state_dict: ShardedStateDict,
+    loaded_items: Dict[_ShardId, Any],
+    item_type: type,
+    shard_id_func: Callable[[T], _ShardId],
+) -> None:
+    """Helper function to fill in items not loaded by current rank."""
+
+    def fill_in_sharded_item(x: Any) -> Any:
+        if isinstance(x, item_type):
+            try:
+                x = loaded_items[shard_id_func(x)]
+            except KeyError as e:
+                raise CheckpointingException(
+                    f'Missing loaded item shard: {shard_id_func(x)}'
+                ) from e
+        return x
+
+    dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict)
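`debug_time`, imported above from `megatron.core.dist_checkpointing.utils`, is used both as a decorator (on `load`) and as a context manager (around individual phases). A self-contained sketch of that dual-use pattern with `contextlib.ContextDecorator`; this illustrates the idea only and is not Megatron's actual implementation:

    import logging
    import time
    from contextlib import ContextDecorator

    class debug_time_sketch(ContextDecorator):
        """Times a block or a function and logs the elapsed seconds at DEBUG level."""

        def __init__(self, name: str, logger: logging.Logger):
            self.name = name
            self.logger = logger

        def __enter__(self):
            self.start = time.monotonic()
            return self

        def __exit__(self, *exc):
            self.logger.debug(f"{self.name} took {time.monotonic() - self.start:.4f}s")
            return False

    log = logging.getLogger(__name__)

    @debug_time_sketch("load", log)          # decorator usage, as on load() above
    def load():
        with debug_time_sketch("apply_loading_parallelization", log):  # context-manager usage
            time.sleep(0.01)

    load()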
@@ -13,7 +13,7 @@ import logging
 import math
 from dataclasses import dataclass
 from itertools import product
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Tuple, Union

 import numpy as np
 import torch
@@ -27,7 +27,6 @@ from megatron.core.dist_checkpointing.dict_utils import (
     extract_matching_values,
 )
 from megatron.core.dist_checkpointing.mapping import (
-    ReplicaId,
     ShardedStateDict,
     ShardedTensorFactory,
     StateDict,
@@ -84,11 +83,7 @@ def is_nd_flattened_tensor(sh_ten: Any) -> bool:
     Returns:
         bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
     """
-    return (
-        isinstance(sh_ten, ShardedTensor)
-        and sh_ten.flattened_range is not None
-        and len(sh_ten.global_shape) > 1
-    )
+    return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None


 # information needed to restore. With current implementation, this is a nested state dict
@@ -132,8 +127,12 @@ def apply_nd_flattened_tensors_reformulation(
         try:
             sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
         except KeyError as e:
+            # Handle legacy checkpointing where 1-D flatten tensor metadata was not saved
+            if len(sh_ten.global_shape) == 1:
+                return sh_ten
             raise CheckpointingException(
-                f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}'
+                f'Missing reformulation metadata for tensor {sh_ten}. '
+                f'Existing keys: {reformulation_metadata.keys()}'
             ) from e

         ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
@@ -235,13 +234,16 @@ def reformulate_single_nd_flattened_tensor(
     ):
         # without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
         first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
-        # `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units
+        # `math.ceil` argument is an exact offset of the app next shard expressed
+        # in ckpt_local_shape units
         next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
         overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))

     logger.debug(
-        f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}'
-        f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}'
+        f'Generated the following number of overlap shards for each dimension: '
+        f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt '
+        f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} '
+        f'and chunk offset {sh_ten.local_chunk_offset_in_global()}'
     )
     reformulated_sh_tens = {}
     for chunk_offset in product(*overlap_dim_offsets):
@@ -286,7 +288,8 @@ def reformulate_single_nd_flattened_tensor(
         # For each ckpt shard, we fill the appropriate application shard part
         dest_ten = app_non_flat_ten
         src_ten = ckpt_ten.view(ckpt_local_shape)
-        # We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice
+        # We don't need narrowing over `prepend_axis_num` axes so we take
+        # the [sh_ten.prepend_axis_num:] offsets slice
         for (
             dim,
             offset_for_saved_tensor,
...
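For reference, the reformulation these helpers operate on (described in the planner docstrings later in this commit) stores an N-D flattened tensor of global shape [X, Y] with local chunks [x, y] as a [X // x, Y // y, x * y] tensor whose last axis follows the `flattened_range` slices. A small numeric check with made-up shapes:

    # Hedged numeric illustration of the N-D flattened reformulation.
    X, Y = 8, 4          # hypothetical global shape
    x, y = 2, 4          # hypothetical local chunk shape
    reformulated = (X // x, Y // y, x * y)
    assert reformulated == (4, 1, 8)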
@@ -4,7 +4,7 @@
 from logging import getLogger
 from time import time
-from typing import TYPE_CHECKING, Optional, Tuple, cast
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -16,19 +16,37 @@ from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict

 if TYPE_CHECKING:
     from .filesystem_async import FileSystemWriterAsync
+    from .torch import MCoreSavePlanner


 logger = getLogger(__name__)

+from dataclasses import fields
+
+
+def _compare_dataclasses(obj1, obj2):
+    if type(obj1) != type(obj2):
+        return f"Objects are of different types: {type(obj1)} and {type(obj2)}"
+
+    differences = []
+    for field in fields(obj1):
+        value1 = getattr(obj1, field.name)
+        value2 = getattr(obj2, field.name)
+        if value1 != value2:
+            differences.append(f"{field.name}: {value1} != {value2}")
+
+    return differences if differences else "All fields are equal"
+

 def save_state_dict_async_plan(
     state_dict: STATE_DICT_TYPE,
     storage_writer: 'FileSystemWriterAsync',
     process_group: Optional[dist.ProcessGroup] = None,
     coordinator_rank: int = 0,
-    planner: Optional[SavePlanner] = None,
+    planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
     cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
-) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]:
+    loaded_all_plans: Optional[List[SavePlan]] = None,
+) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]:
     """
     First stage of saving a state dict to storage.
@@ -62,7 +80,7 @@ def save_state_dict_async_plan(
     Returns: Tuple of:
         - storage writer (the one passed as input)
-        - metadata from planning
+        - metadata from planning (or None if we reuse cached global metadata)
         - distributed wrapper used for planning
         The return value of this function should be passed as an input to
         `save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning.
@@ -80,6 +98,7 @@ def save_state_dict_async_plan(
     global_metadata = None
     logger.debug(f"rank: {rank}, starting state dict save")
     local_plan = cached_local_plan
+    global_md_verify_reuse = False

     def local_step():
         nonlocal local_plan
@@ -101,11 +120,34 @@ def save_state_dict_async_plan(
         return all_local_plans

     # Execute local and global planning
+    # Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
+    # allow it (`can_run_decentralized_global_plan`) we gather the plans to create
+    # the metadata but prepare the plans independently on each rank.
+    # In the worst case we have to reduce_scatter all the plans.
     start_plan = time()
     if validated_cache_reuse and cached_central_plan:
         logger.debug(f"rank: {rank}, Passed cache reusable")
         local_step()
         central_plan = cached_central_plan
+    elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
+        storage_writer, 'can_run_decentralized_global_plan', False
+    ):
+        local_plan = local_step()
+        global_md_verify_reuse = verify_global_md_reuse(
+            loaded_all_plans, local_plan, rank, dist_wrapper
+        )
+
+        if not loaded_all_plans or not global_md_verify_reuse:
+            all_local_plans = dist_wrapper.gather_object(local_plan)
+            if dist_wrapper.is_coordinator:
+                _, global_metadata = planner.create_global_plan(all_local_plans)
+                global_metadata.all_local_plans = all_local_plans
+        else:
+            logger.debug(f"rank: {rank}, Passed cached global metadata")
+            global_metadata = None
+        local_plan = planner.create_decentralized_global_plan(local_plan)
+        local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
+        central_plan = local_plan
     else:
         central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
     central_plan = planner.finish_plan(central_plan)
@@ -118,13 +160,56 @@ def save_state_dict_async_plan(
     end = time()
     logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
     return (
-        (storage_writer, cast(Metadata, global_metadata), dist_wrapper),
+        (storage_writer, global_metadata, dist_wrapper),
         central_plan,
         local_plan,
         cached_central_plan == central_plan,
+        global_md_verify_reuse,
     )


+def verify_global_md_reuse(
+    loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
+) -> bool:
+    """
+    Verifies that global metadata reuse is possible by checking the loaded plans from the
+    checkpoint are consistent, which means we have the same settings when resuming training.
+
+    Args:
+        loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
+            (stored in checkpoint metadata).
+        local_plan: SavePlan, The local save plan.
+        rank: Current process rank.
+        dist_wrapper (_DistWrapper): distributed wrapper created during planning
+
+    Returns: True iff the global metadata reuse is possible.
+    """
+    logger.debug(f"verifying reuse of global metadata")
+    if not loaded_all_plans:
+        global_md_verify_reuse = False
+        logger.debug("loaded global metadata reuse verification: no loaded plans passed")
+
+    elif len(loaded_all_plans) == dist_wrapper.get_world_size():
+        local_verify_reuse = all(
+            getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
+            for f in fields(local_plan)
+            if f.name != 'storage_data'
+        )
+
+        if not local_verify_reuse:
+            logger.debug(
+                f"local_verify_reuse is False: diffs -"
+                f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
+            )
+        all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
+        torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
+        # Check if all reduced results are True
+        global_md_verify_reuse = all_results.item() == 1
+    else:
+        global_md_verify_reuse = False
+    return global_md_verify_reuse
+
+
 def save_state_dict_async_finalize(
     storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
 ) -> None:
...
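`verify_global_md_reuse` above accepts reuse only when every rank agrees: each rank computes a local boolean and the group reduces with MIN, so a single mismatching rank forces regeneration of the global metadata. A distilled sketch of that agreement step (assumes an initialized default process group and a CUDA device):

    import torch
    import torch.distributed as dist

    def all_ranks_agree(local_ok: bool) -> bool:
        # MIN over {0, 1}: the result is 1 only if every rank contributed True.
        flag = torch.tensor([int(local_ok)], dtype=torch.int, device='cuda')
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        return flag.item() == 1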
File mode changed from 100755 to 100644
@@ -55,6 +55,7 @@ from .base import (
     StrategyAction,
     register_default_strategy,
 )
+from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
 from .filesystem_async import FileSystemWriterAsync
 from .resharding import (
     TensorReformulationMetadata,
@@ -126,7 +127,9 @@ def flatten_state_dict(

 def sharded_tensor_to_torch_sharded_tensor(
-    sh_tens: List[ShardedTensor], rank: Optional[int] = None
+    sh_tens: List[ShardedTensor],
+    rank: Optional[int] = None,
+    load_legacy_1d_flatten_tensors: bool = False,
 ) -> TorchShardedTensor:
     """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
@@ -138,13 +141,12 @@ def sharded_tensor_to_torch_sharded_tensor(
     NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
     The only local irregularities could be introduced with a `flattened_range` attribute.

-    This function handles 3 different type of ShardedTensors:
+    This function handles 2 different type of ShardedTensors:
     1. Non-flat regular ShardedTensors (`not has_flattened_range`)
-    2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
-    3. N-D flattened ShardedTensors (`has_flattened_range`)
+    2. N-D flattened ShardedTensors (`has_flattened_range`)

-    (1) and (2) type are saved according to their original shape.
-    Type (3) however requires global shape adjustment for efficiency:
+    (1) type are saved according to their original shape.
+    Type (2) however requires global shape adjustment for efficiency:
     we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
     as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
     partitioned according to `flattened_range` slices.
@@ -154,6 +156,8 @@ def sharded_tensor_to_torch_sharded_tensor(
         sh_tens (List[ShardedTensor]): list of sharded tensors to convert
         rank (int, optional): current process rank passed to PyT ShardedTensor.
             If None, assumes rank in the default pg.
+        load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
+            should be loaded in a legacy way. Defaults to False.

     Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
@@ -163,41 +167,21 @@ def sharded_tensor_to_torch_sharded_tensor(
     some_sh_ten = sh_tens[0]
     has_flattened_range = some_sh_ten.flattened_range is not None
-    is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1

     for sh_ten in sh_tens:
         assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
         if not sh_ten.data.is_contiguous():
             sh_ten.data = sh_ten.data.contiguous()

+    if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
+        # Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
+        has_flattened_range = False
+
     local_global_offsets = {}

     prepend_axis_num = sh_tens[0].prepend_axis_num
     # Determine local shards according to tensor type (see docs)
-    if is_flattened_range_1d:
-        # Type (2) case: 1D flattened ShardedTensors
-        for sh_ten in sh_tens:
-            assert len(sh_ten.global_offset) == 1, sh_ten
-            assert sh_ten.prepend_axis_num == 0, sh_ten
-            local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
-
-        global_shape = some_sh_ten.global_shape
-        offsets_shape = (
-            some_sh_ten.local_shape
-        )  # local shape is not flattened, we need it for chunk offsets
-
-        local_shards = [
-            Shard.from_tensor_and_offsets(
-                sh_ten.data,
-                [
-                    sh_ten.global_offset[0] + sh_ten.flattened_range.start
-                ],  # additional flattened offset
-                rank,
-            )
-            for sh_ten in sh_tens
-        ]
-
-    elif has_flattened_range:
+    if has_flattened_range:
         # Type (3) case: N-D flattened ShardedTensors
         for sh_ten in sh_tens:
             local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
@@ -250,10 +234,7 @@ def sharded_tensor_to_torch_sharded_tensor(
             # local shard
             placement = f"rank:{rank}/cuda"
             for sh_ten in local_global_offsets[offset]:
-                if is_flattened_range_1d:
-                    offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
-                    size = sh_ten.data.shape
-                elif has_flattened_range:
+                if has_flattened_range:
                     assert offset == sh_ten.local_chunk_offset_in_global()
                     # This is not an actual offset, but an offset of the whole shard
                     # This is needed for a PyT Dist internal integrity check
@@ -270,7 +251,7 @@ def sharded_tensor_to_torch_sharded_tensor(
             # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
             # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
             placement = f"rank:{(rank + 1) % world_size}/cuda"
-            if has_flattened_range and not is_flattened_range_1d:
+            if has_flattened_range:
                 offset = offset + (0,)
                 size = (1,) * len(offsets_shape) + global_shape[-1:]
             else:
@@ -296,7 +277,7 @@ def sharded_tensor_to_torch_sharded_tensor(
     # This won't be stored in the checkpoint, only for runtime purposes
     pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
     pyt_sh_ten.mcore_metadata = {}
-    if has_flattened_range and not is_flattened_range_1d:
+    if has_flattened_range:
         pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
     return pyt_sh_ten
@@ -305,6 +286,7 @@ def mcore_to_pyt_state_dict(
     state_dict: Dict[str, List[ShardedBase]],
     is_loading: bool = False,
     init_device: torch.device = torch.device("cpu"),
+    load_legacy_1d_flatten_tensors: bool = False,
 ) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
     """Convert state dict with ShardedTensors and ShardedObjects
     to state dict compatible with PyT Dist format.
@@ -348,7 +330,9 @@ def mcore_to_pyt_state_dict(
             if sh_ten.allow_shape_mismatch and is_loading:
                 sh_ten.data.zero_()

-        torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
+        torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
+            sh_tens, rank, load_legacy_1d_flatten_tensors
+        )
         torch_sh_ten.key = sh_tens[0].key
         return torch_sh_ten
@@ -460,6 +444,7 @@ class MCoreSavePlanner(DefaultSavePlanner):
         *args,
         dedup_replicated_tensors: Optional[bool] = None,
         nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
+        can_run_decentralized_global_plan: bool = True,
         **kwargs,
     ) -> None:
         # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
@@ -468,6 +453,14 @@ class MCoreSavePlanner(DefaultSavePlanner):
             kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
         super().__init__(*args, **kwargs)
         self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
+        self.can_run_decentralized_global_plan = can_run_decentralized_global_plan
+        if can_run_decentralized_global_plan:
+            assert (
+                not dedup_replicated_tensors
+            ), 'Cannot run decentralized plan with dedup_replicated_tensors=True'
+            assert (
+                not self.flatten_state_dict
+            ), 'Cannot run decentralized plan with flatten_state_dict=True'

     def create_local_plan(self) -> SavePlan:
         """Adds IOBytes write request on non-coordinator ranks."""
@@ -503,6 +496,23 @@ class MCoreSavePlanner(DefaultSavePlanner):
         metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
         return global_plan, metadata

+    def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
+        """Nothing to do, just some checks.
+
+        Args:
+            local_plan (SavePlan): local plan to turn to a global plan
+                (without interactions with other ranks)
+
+        Returns:
+            SavePlan - locally transformed plan equivalent to the plan that would be
+                created by the coordinator
+        """
+        assert (
+            not self.flatten_state_dict
+        ), 'Cannot run decentralized plan with flatten_state_dict=True'
+        assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan'
+        return local_plan
+
     def transform_object(self, write_item: WriteItem, object: Any):
         """Make no transformations - bytes objects are already serialized."""
         return object
...@@ -535,6 +545,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): ...@@ -535,6 +545,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
else: else:
expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten) expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if loaded_shape != expected_shape: if loaded_shape != expected_shape:
if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape = sh_ten.global_shape
if loaded_shape == expected_shape:
continue
_msg = ( _msg = (
f'Global shape mismatch for loaded ({loaded_shape})' f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor' f' and expected ({expected_shape}) tensor'
...@@ -634,6 +650,8 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): ...@@ -634,6 +650,8 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.separation_hint = separation_hint self.separation_hint = separation_hint
self.validated_loaded_metadata_reuse = False
def async_save( def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest: ) -> AsyncRequest:
...@@ -663,7 +681,14 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): ...@@ -663,7 +681,14 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused # (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure: if self.use_cached_ckpt_structure:
loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
if loaded_all_plans is None:
logger.debug(
"no all_local_plans in metadata - can't verify global metadata reuse..."
)
args_cached_plans = ( args_cached_plans = (
self.cached_central_plan, self.cached_central_plan,
self.cached_local_plan, self.cached_local_plan,
...@@ -675,24 +700,44 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): ...@@ -675,24 +700,44 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.cached_central_plan, self.cached_central_plan,
self.cached_local_plan, self.cached_local_plan,
self.validated_cache_reuse, self.validated_cache_reuse,
self.validated_loaded_metadata_reuse,
) = save_state_dict_async_plan( ) = save_state_dict_async_plan(
pyt_state_dict, pyt_state_dict,
writer, writer,
None, None,
coordinator, coordinator,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica), planner=MCoreSavePlanner(
dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False
),
cached_ckpt_structure=args_cached_plans, cached_ckpt_structure=args_cached_plans,
loaded_all_plans=loaded_all_plans,
) )
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
if self.use_cached_ckpt_structure: if self.use_cached_ckpt_structure:
if self.validated_cache_reuse: if (
loaded_all_plans
and self.cached_global_metadata
and self.validated_loaded_metadata_reuse
):
if coordinator == rank:
logger.debug(
f"rank: {rank}, reuse global metadata from loaded"
f" .metadata, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
elif self.validated_cache_reuse:
logger.debug(f"rank: {rank}, cache validated") logger.debug(f"rank: {rank}, cache validated")
if save_state_dict_ret[1]: # when global_metadata is not cached if save_state_dict_ret[1]: # when global_metadata is not cached
self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata
# Only Coordinator rank holds cached global_metadata # Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata) # (None is returned for global_metadata)
elif coordinator == rank: elif coordinator == rank:
logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}") logger.debug(
f"rank: {rank}, reuse global metadata cached from previous"
f" save iteration, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret) save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata save_state_dict_ret[1] = self.cached_global_metadata
...@@ -700,13 +745,13 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy): ...@@ -700,13 +745,13 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest: def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args() save_fn_args = writer.get_save_function_and_args()
save_fn, save_args = save_fn_args save_fn, preload_fn, save_args = save_fn_args
def finalize_fn(): def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret) save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier() torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn]) return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
def can_handle_sharded_objects(self): def can_handle_sharded_objects(self):
return True return True
...@@ -736,6 +781,12 @@ def get_reformulation_metadata( ...@@ -736,6 +781,12 @@ def get_reformulation_metadata(
'nd_reformulated_orig_global_shape' 'nd_reformulated_orig_global_shape'
] ]
except KeyError as e: except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skipping metadata reformulation.'
)
continue
raise CheckpointingException( raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} ' f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}' f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
...@@ -750,6 +801,10 @@ def get_reformulation_metadata( ...@@ -750,6 +801,10 @@ def get_reformulation_metadata(
class TorchDistLoadShardedStrategy(LoadShardedStrategy): class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format.""" """Basic load strategy for the PyT Distributed format."""
def __init__(self):
self.cached_global_metadata: Optional[Metadata] = None
super().__init__()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
...@@ -761,10 +816,18 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): ...@@ -761,10 +816,18 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
Returns: loaded state dict Returns: loaded state dict
""" """
# Apply N-D tensors resharding # Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation( sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir) sharded_state_dict, reformulation_metadata
) )
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break
flexible_shape_sharded_tensors = [ flexible_shape_sharded_tensors = [
sh_ten sh_ten
for sh_ten in nested_values(sharded_state_dict) for sh_ten in nested_values(sharded_state_dict)
...@@ -776,15 +839,23 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): ...@@ -776,15 +839,23 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
(sharded_state_dict, flat_mapping, rename_mapping) = ( (sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict) _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
) )
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True) pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format # Load PyT Distributed format
fsr = CachedMetadataFileSystemReader(checkpoint_dir)
checkpoint.load_state_dict( checkpoint.load_state_dict(
pyt_state_dict, pyt_state_dict,
FileSystemReader(checkpoint_dir), fsr,
planner=MCoreLoadPlanner( planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
), ),
) )
self.cached_global_metadata = (
fsr.read_metadata()
) # no storage interaction thanks to caching
pyt_state_dict = cast( pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
) )
......
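A short sketch of the metadata caching used in the load path above (the checkpoint path is hypothetical, and the reader class is assumed to be in scope as it is inside `load()`):

# Illustrative sketch: the cached reader touches storage only on the first
# read_metadata() call; later calls return the same in-memory object.
from pathlib import Path

ckpt_dir = Path('/checkpoints/iter_0000100')        # hypothetical location
fsr = CachedMetadataFileSystemReader(ckpt_dir)      # same reader class used in load() above
first = fsr.read_metadata()                         # reads .metadata from storage once
again = fsr.read_metadata()                         # no storage interaction, served from cache
assert first is again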
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """ """ 2-stage checkpoint loading. """
import os
import time import time
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial, wraps from functools import partial, wraps
from itertools import chain from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger from logging import getLogger
from operator import attrgetter, itemgetter from operator import attrgetter, itemgetter
from pathlib import Path from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array from .tensorstore import _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None _import_trigger = None
...@@ -26,9 +25,16 @@ _import_trigger = None ...@@ -26,9 +25,16 @@ _import_trigger = None
timers = defaultdict(list) timers = defaultdict(list)
logger = getLogger(__name__) logger = getLogger(__name__)
logger.warning(
'megatron.core.dist_checkpointing.two_stage module is deprecated'
' and will be removed in Megatron-Core v0.12. Please use'
' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
)
def timed(verbose=True): def timed(verbose=True):
"""Timing decorator."""
def timed_dec(fn): def timed_dec(fn):
name = fn.__name__ name = fn.__name__
...@@ -59,6 +65,7 @@ class _ShardedTensorMetadata: ...@@ -59,6 +65,7 @@ class _ShardedTensorMetadata:
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor): def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
"""Id of a sharded tensor."""
return (sharded_tensor.key, sharded_tensor.global_offset) return (sharded_tensor.key, sharded_tensor.global_offset)
...@@ -101,6 +108,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): ...@@ -101,6 +108,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
self.global_rank = torch.distributed.get_rank() self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Main load method."""
self.maybe_init_gloo_group() self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict) all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir) self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
...@@ -109,6 +117,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): ...@@ -109,6 +117,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
return sharded_state_dict return sharded_state_dict
def summarize_load_times(self): def summarize_load_times(self):
"""Summarize load times."""
torch.distributed.barrier() torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:') logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
...@@ -124,6 +133,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): ...@@ -124,6 +133,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed(verbose=False) @timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata): def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
"""Load tensor from storage."""
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init') logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array( ret = _load_from_array(
ten_meta.sharded_tensor_no_data, ten_meta.sharded_tensor_no_data,
...@@ -136,12 +146,15 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): ...@@ -136,12 +146,15 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed() @timed()
def maybe_init_gloo_group(self): def maybe_init_gloo_group(self):
"""Create Gloo groups."""
if not self.cpu_transfer: if not self.cpu_transfer:
return return
all_groups = [None] * torch.distributed.get_world_size() all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks) torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups) all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups): for group_ranks in sorted(all_groups):
# "two_stage" module will be deprecated, so not replace new_group()
# with ...parallel_state.create_group() func setting group_desc here.
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo') gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks: if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg self.data_parallel_group = gloo_pg
...@@ -211,7 +224,8 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy): ...@@ -211,7 +224,8 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
) )
logger.debug( logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})' f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}'
f'({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
) )
torch.distributed.broadcast( torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank exchange_tensor, group=self.data_parallel_group, src=src_rank
......
File mode changed from 100755 to 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values
from .exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
)
from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges
from .state_dict_utils import load_preprocess, save_preprocess
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
debug_time,
extract_sharded_base,
zip_strict,
)
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
@dataclass
class MCoreTensorAwareStateDict(TensorAwareStateDict):
"""
MCore-specific class defining the interface between the MCore state dict and checkpoint manager.
This class distinguishes between raw objects, the common state dict, and sharded state dicts
(tensor parts). It also handles optional metadata needed for fully parallel save/load.
"""
common: StateDict
sharded_state_dict: ShardedStateDict
_is_hollow: bool = False
@staticmethod
def _validate_params(algo):
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
@staticmethod
def _get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_distribution=None
):
if fully_parallel:
if cached_distribution is None:
distribution = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
logger.debug(f'MCore_TASD._get_distribution calculated distribution')
else:
distribution = cached_distribution
logger.debug(f'MCore_TASD._get_distribution used cache')
else:
distribution = (None, None, None, None)
logger.debug(f'MCore_TASD._get_distribution returned empty distribution')
return distribution
@staticmethod
def _remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
):
if fully_parallel:
for sh_base in nested_values(sharded_part):
# TODO remove redundant objects as well
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_to_saving_rank[shard_id] != torch.distributed.get_rank(
group=parallelization_group
):
sh_base.data = None
@classmethod
@debug_time("from_state_dict", logger)
def from_state_dict(
cls,
sharded_state_dict: ShardedStateDict,
algo: str = 'fully_parallel',
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
cached_metadata: ShardDistribution = None,
) -> Tuple[TensorAwareStateDict, ShardDistribution]:
"""
Constructs a TensorAwareStateDict from a sharded state dictionary.
This method preprocesses the input `sharded_state_dict`, validates parameters,
and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`.
Args:
sharded_state_dict: The input sharded state dictionary to be converted.
algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'.
- 'fully_parallel' enables fully parallel initialization.
parallelization_group (Optional): A distributed process group for parallelization.
cached_metadata (Optional): Precomputed metadata from previous saves.
- Reuses data that doesn't need recalculation, optimizing the creation process.
Returns:
Tuple[TensorAwareStateDict, ShardDistribution]: the state dict instance initialized
with the provided sharded state dictionary, together with the shard distribution
metadata (None for non-parallel algorithms), which can be kept in memory
to speed up future saves.
"""
with debug_time("_get_distribution", logger):
cls._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(
sharded_state_dict, cached_metadata is None
)
cacheable_distribution = cls._get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_metadata
)
if cacheable_distribution is not None:
shard_to_saving_rank, _, _, _ = cacheable_distribution
cls._remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
)
return (
MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part),
cacheable_distribution,
)
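A usage sketch of the cached-metadata pattern described in the docstring above (the `model` object and its `sharded_state_dict()` call are hypothetical placeholders for any MCore module):

# Illustrative sketch: reuse the shard distribution computed on the first save.
cached_metadata = None
for step in range(3):
    sharded_sd = model.sharded_state_dict()          # hypothetical MCore module call
    ta_state_dict, cached_metadata = MCoreTensorAwareStateDict.from_state_dict(
        sharded_sd,
        algo='fully_parallel',
        cached_metadata=cached_metadata,             # None on the first save only
    )
    # ta_state_dict.tensors / ta_state_dict.common_state_dict can now be checkpointed.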
@property
def is_hollow(self):
"""
True iff tensors have been extracted and have not been inserted back yet.
"""
return self._is_hollow
@property
def _sharded_tensors(self):
# Three possible states for sharded_tensor:
# 1. sharded_tensor with data (.data = tensor)
# 2. sharded_tensor hollow (.data = None, .orig_device = orig_device)
# 3. removed sharded_tensor (.data = None, no device information)
# TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data
if self.is_hollow:
for sh_base in nested_values(self.sharded_state_dict):
# FIXME: Hacky way to store the original device of the popped tensor
if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'):
yield sh_base
else:
for sh_base in nested_values(self.sharded_state_dict):
if isinstance(sh_base, ShardedTensor) and sh_base.data is not None:
yield sh_base
@property
def tensors(self) -> Iterator[torch.Tensor]:
"""
Get the tensor data from the state dict.
"""
assert not self.is_hollow # TODO raise exception
return map(lambda sh_ten: sh_ten.data, self._sharded_tensors)
@property
def common_state_dict(self) -> Dict:
"""
Get the common state dict from the state dict.
"""
return self.common
def pop_tensors(self) -> List[torch.Tensor]:
"""
Extracts the tensor data from the wrapped state dict, preserving metadata.
Replaces the tensor data in sharded_tensors with the device type of the extracted tensors.
After this operation, the state dictionary is "hollow", containing no tensor data.
Further calls to `pop_tensors` will raise an error.
Returns:
List of extracted tensors
"""
assert not self.is_hollow # TODO raise exception
result = []
for sh_ten in self._sharded_tensors:
result.append(sh_ten.data)
# FIXME: Hacky way to store the original device, which is not included in the metadata
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = None
self._is_hollow = True
return result
def insert_tensors(self, tensor_data: Iterable[torch.Tensor]):
"""
Reverse of `pop_tensors`. Replaces the device type in sharded_tensors with the actual tensor values.
The value of `self` is considered to be the same after:
```
self.insert_tensors(self.pop_tensors())
```
"""
assert self.is_hollow # TODO raise exception
for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data):
# FIXME: Hacky way to store the original device
if sh_ten.orig_device == ten.device.type:
delattr(sh_ten, 'orig_device')
# Tensor might be on non-original device
sh_ten.data = ten
self._is_hollow = False
def init_tensors(self):
"""
Initializes empty tensors with the same properties as the original tensors.
This function should only be called after the original tensors have been popped.
It ensures that the newly created empty tensors match the shape,
dtype, and device of the originals, but contain no data.
"""
assert self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# Hacky way to retrieve the original device
sh_ten.init_data(sh_ten.orig_device)
delattr(sh_ten, 'orig_device')
self._is_hollow = False
def copy_tensors_to_cpu(self, non_blocking=False):
"""
Stores CPU copies of tensors in the state_dict, replacing the originals,
but without destroying them.
The original devices are remembered for restoration with restore_tensor_device().
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
if sh_ten.data.device.type == 'cpu':
# Skip cloning if it's already confirmed to be a copy
if not hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.clone()
else:
# FIXME: Hacky way to store the original device
if not hasattr(sh_ten, 'orig_device'):
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking)
def restore_tensor_device(self, non_blocking=True):
"""
Restores all tensors to their original devices, if a move is required.
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# FIXME: Hacky way to store the original device
if hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking)
delattr(sh_ten, 'orig_device')
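A short sketch of the CPU staging round trip these two methods enable (assumes `ta_state_dict` is an MCoreTensorAwareStateDict holding GPU tensors; the explicit synchronization is an assumption about how a caller would consume the asynchronous copies):

# Illustrative sketch: stage tensors on CPU asynchronously, then restore them.
ta_state_dict.copy_tensors_to_cpu(non_blocking=True)    # queue device-to-host copies
torch.cuda.synchronize()                                # wait before reading the CPU copies
# ... hand the CPU copies to a local checkpoint writer ...
ta_state_dict.restore_tensor_device(non_blocking=True)  # move back to the remembered devices
torch.cuda.synchronize()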
def _insert_sharded_data(
self, fully_parallel, sharded_part, parallelization_group, exchange_algo
):
loaded_tensors = {}
for sh_ten in self._sharded_tensors:
loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data
if fully_parallel:
with debug_time("_get_distribution", logger):
distribution = self._get_distribution(
fully_parallel, sharded_part, parallelization_group
)
if distribution is not None:
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
# TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
with debug_time("exchange_by_distribution", logger):
loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
distribution,
parallelization_group,
exchange_algo,
)
torch.cuda.synchronize()
loaded_objects = {}
for sh_base in nested_values(self.sharded_state_dict):
if not isinstance(sh_base, ShardedTensor):
assert isinstance(sh_base, ShardedObject)
loaded_objects[_sharded_object_id(sh_base)] = sh_base.data
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys())
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
@debug_time("to_state_dict", logger)
def to_state_dict(
self,
sharded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Convert tensor-aware dict back to the original state_dict
"""
with debug_time("load_preprocess_and_state_dict_manipulations", logger):
assert not self.is_hollow # TODO raise exception
self._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate self.sharded_state_dict and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
with debug_time("validate_sharding_integrity", logger):
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
with debug_time("_insert_sharded_data", logger):
self._insert_sharded_data(
fully_parallel, sharded_part, parallelization_group, exchange_algo
)
with debug_time("apply_factory_merges", logger):
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
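To tie the pieces together, a hedged end-to-end sketch of the hollow round trip and the reverse conversion (again, `sharded_sd` is a hypothetical sharded state dict, and an initialized torch.distributed process group is assumed, as elsewhere in this module):

# Illustrative round trip: pop tensors, reinsert them, and convert back.
ta_state_dict, _ = MCoreTensorAwareStateDict.from_state_dict(sharded_sd, algo='fully_parallel')
raw_tensors = ta_state_dict.pop_tensors()     # ta_state_dict is now "hollow"
assert ta_state_dict.is_hollow
ta_state_dict.insert_tensors(raw_tensors)     # equivalent to the state before popping
restored_sd = ta_state_dict.to_state_dict(sharded_sd)   # defaults: algo='atomic'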
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """ """ Helpers for manipulating sharded tensors and sharded state dicts. """
import logging
from contextlib import contextmanager
from time import time
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values from .dict_utils import dict_list_map_inplace, extract_matching_values
...@@ -20,6 +22,18 @@ from .mapping import ( ...@@ -20,6 +22,18 @@ from .mapping import (
_ShardId = Tuple[str, tuple, Optional[tuple]] _ShardId = Tuple[str, tuple, Optional[tuple]]
def zip_strict(*args):
"""
Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).
Apart from providing the functionality in earlier versions of Python, it is also more verbose.
(Python's zip does not print lengths, only which iterable has finished earlier)
"""
args = [list(a) for a in args]
lens = [len(a) for a in args]
assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
return zip(*args)
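A brief usage sketch of `zip_strict` (illustrative; the metadata and tensor values are hypothetical): it behaves like `zip` when the lengths match and raises an `AssertionError` reporting both lengths otherwise.

# Usage sketch (illustrative, not part of this module).
import torch

names = ['layer0.weight', 'layer0.bias']             # hypothetical shard metadata
tensors = [torch.zeros(2, 2), torch.zeros(2)]        # hypothetical tensor data

for name, ten in zip_strict(names, tensors):         # same as zip(..., strict=True) on 3.10+
    print(name, tuple(ten.shape))

# Unequal lengths fail fast with both lengths in the message:
# zip_strict(names, tensors[:1])
#   -> AssertionError: Tried to zip iterables of unequal lengths: [2, 1]!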
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId: def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
"""Unique id of the sharded tensor data. """Unique id of the sharded tensor data.
...@@ -217,3 +231,89 @@ def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[ ...@@ -217,3 +231,89 @@ def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[
return x return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict) dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
fallback_logger = logging.getLogger(__name__)
__LOGGER_NAME_STACK = []
__LOGGER_STACK = []
@contextmanager
def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None):
"""Context manager for managing logger and name stack.
Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical
logging and contextual logger usage. Ensures the logger stack is restored afterward.
Args:
name (str, optional): Name to add to the logger stack. Defaults to None.
current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in
the stack or a fallback if none exist.
Yields:
Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and
the current logger for the block.
Example:
with logger_stack("scope", logger):
logger.info("Log within 'scope'")
"""
if name:
__LOGGER_NAME_STACK.append(name)
if current_logger:
__LOGGER_STACK.append(current_logger)
last_logger = current_logger
elif __LOGGER_STACK:
last_logger = __LOGGER_STACK[-1]
else:
last_logger = fallback_logger
try:
yield ".".join(__LOGGER_NAME_STACK), last_logger
finally:
if name and __LOGGER_NAME_STACK:
__LOGGER_NAME_STACK.pop(-1)
if current_logger and __LOGGER_STACK:
__LOGGER_STACK.pop(-1)
@contextmanager
def debug_time(
name: str, logger: Optional[logging.Logger] = None, threshold: float = float("-inf"), level=None
):
"""Simple context manager for timing functions/code blocks.
Args:
name (str): Label describing the code being measured.
logger (logging.Logger, optional): Logger for output. Defaults to the innermost logger on the current logger stack, or a module-level fallback logger.
threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster.
level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset;
WARNING otherwise.
"""
with logger_stack(name, logger) as (stacked_name, last_logger):
start = time()
try:
yield
finally:
result = time() - start
if result < threshold:
return
if level is None:
level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
last_logger.log(level, f"{stacked_name} took {result:.4f}s")
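A usage sketch of nested `debug_time` blocks (illustrative; the logger name is hypothetical): inner blocks reuse the logger pushed by the outer block and report dotted names such as "save.plan".

# Illustrative sketch: nested timers share the outer logger and produce dotted names.
import logging

logging.basicConfig(level=logging.DEBUG)
ckpt_logger = logging.getLogger('ckpt')          # hypothetical logger name

with debug_time('save', ckpt_logger):
    with debug_time('plan'):                     # reuses ckpt_logger from the stack
        pass                                     # logs e.g. "save.plan took 0.0001s" at DEBUG
    with debug_time('write', threshold=1.0):     # only logged (at WARNING) if it takes >= 1s
        pass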
def debug_msg(msg: str):
"""Logs a debug message using the current logger stack.
This function formats and logs a debug message with the current logger
and name stack, preserving context from the logger_stack context manager.
Args:
msg (str): The message to be logged at the debug level.
Example:
debug_msg("Checkpoint initialized")
# Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
"""
with logger_stack(None, None) as (stacked_name, last_logger):
last_logger.debug(f"{stacked_name} {msg}")
...@@ -412,7 +412,7 @@ def validate_sharding_integrity( ...@@ -412,7 +412,7 @@ def validate_sharding_integrity(
CheckpointingException for invalid access pattern CheckpointingException for invalid access pattern
""" """
if common_state_dict: if common_state_dict is not None:
_validate_common_state_dict(common_state_dict) _validate_common_state_dict(common_state_dict)
if torch.distributed.get_rank() != 0: if torch.distributed.get_rank() != 0:
...@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): ...@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
lambda x: x[1], lambda x: x[1],
_validate_sharding_for_key_flattened, _validate_sharding_for_key_flattened,
) )
else: # For each shard with at least 1 flattened tensor in it, the above
if not torch.all(shard_access_cnt == 1): # `_validate_sharding_for_key_flattened` ensures a correct, consistent pattern.
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') # The only thing that can go wrong at this point is that some shards don't have
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') # *any* representatives, which is checked below by comparing `shard_access_cnt == 1`.
shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
if not torch.all(shard_access_cnt == 1):
raise CheckpointingException(
f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
)
def _compute_shards_access(rank_sharding): def _compute_shards_access(rank_sharding):
...@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard): ...@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices))) starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if ( expected_size = np.product(local_shape)
starts[0] != 0 if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException( raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}"
) )
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .loss_func import loss_func from .fully_sharded_data_parallel import FullyShardedDataParallel
from .model_provider import model_provider