Commit 160bf237 authored by wangxj's avatar wangxj
Browse files

更新0.12

parent b01809dd
Pipeline #2448 failed with stages
File mode changed from 100755 to 100644
......@@ -4,17 +4,36 @@
This module provides an async utilities which allow to start
a checkpoint save process in the background.
"""
import gc
import logging
from abc import ABC, abstractmethod
from collections import deque
from time import time
from typing import Callable, List, NamedTuple, Optional, Tuple
from contextlib import contextmanager
from queue import Empty
from time import sleep, time
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
import torch
from torch import multiprocessing as mp
from ..utils import debug_time
logger = logging.getLogger(__name__)
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
class AsyncRequest(NamedTuple):
"""Represents an async request that needs to be scheduled for execution.
......@@ -24,12 +43,22 @@ class AsyncRequest(NamedTuple):
finalize_fns (List[Callable]): list of functions to call to finalize the request.
These functions will be called synchronously after `async_fn` is done
*on all ranks*.
async_fn_kwargs (Tuple): kwargs to pass to `async_fn`.
preload_fn (Callable): preload function to stage tensors from GPU to Host.
This should be self-contained with a proper list of arguments with `partial`.
is_frozen (Bool): a flag to indicate this async request can be modified or not.
call_idx (int): index variable used to order async requests for synchronization
in preloading and writing tensors on the async caller
"""
async_fn: Optional[Callable]
async_fn_args: Tuple
finalize_fns: List[Callable]
async_fn_kwargs: Dict = {}
preload_fn: Callable = None
is_frozen: bool = False
call_idx: int = 0
def add_finalize_fn(self, fn: Callable) -> None:
"""Adds a new finalize function to the request.
......@@ -66,7 +95,70 @@ class AsyncRequest(NamedTuple):
return self._replace(is_frozen=True)
class DistributedAsyncCaller:
class AsyncCaller(ABC):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
@abstractmethod
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Schedule `async_req` with some process forking or reusing
persistent worker
This method must be called on all ranks.
Args:
async_req (AsyncRequest): `AsyncRequest` object containing to
start async process
"""
raise NotImplementedError("This should be implemented")
@abstractmethod
def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, training ranks simply check its
asynchronous checkpoint writer without synchronization.
Returns:
bool: True if all ranks are done (immediately of after active wait
if `blocking` is True), False if at least one rank is still active.
"""
raise NotImplementedError("This should be implemented")
def sync_all_async_calls(self, is_alive: int) -> bool:
"""Check if all ranks have completed async checkpoint writing
Args:
is_alive (bool): if True, the current async request is not completed
Returns:
bool: True if all ranks are done, False if at least one rank is still active.
"""
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten)
return ten[0] == 0
@abstractmethod
def close(self):
"""Terminate the async caller at exit of an application or some termination conditions"""
logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")
def __del__(self):
self.close()
class TemporalAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
......@@ -76,7 +168,8 @@ class DistributedAsyncCaller:
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None
def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None:
@_disable_gc()
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
......@@ -84,27 +177,35 @@ class DistributedAsyncCaller:
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
save_args (Tuple): async function args.
async_req (AsyncRequest): `AsyncRequest` object containing to
start async process
"""
if async_fn is None:
if async_req.async_fn is None:
return # nothing to do
async_fn_args = list(async_req.async_fn_args)
if async_req.preload_fn:
# If there's a preload_fn in `async_req`, we call this func
# to do the defined action in `async_req.preload_fn` to
# stage GPU tensors to its defined destination
async_fn_args[1] = async_req.preload_fn()
rank = torch.distributed.get_rank()
start_sync = time()
torch.cuda.synchronize()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {end_sync - start_sync} to finish D2H "
)
logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ")
ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(target=async_fn, args=save_args)
self.process = ctx.Process(
target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs
)
self.process.start()
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} to schedule async ckpt "
)
logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ")
def is_current_async_call_done(self, blocking=False) -> bool:
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
......@@ -114,31 +215,229 @@ class DistributedAsyncCaller:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, training ranks simply check its
asynchronous checkpoint writer without synchronization.
Returns:
bool: True if all ranks are done (immediately of after active wait
if `blocking` is True), False if at least one rank is still active.
"""
# The following takes the same overhead as torch.distributed.barrier (single integer all-reduce)
# The following takes the same overhead
# as torch.distributed.barrier (single integer all-reduce)
is_alive = int(self.process.is_alive()) if self.process is not None else 0
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
if not is_done and blocking:
self.close()
is_done = True
return is_done
def close(self):
if self.process:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
"TemporalAsyncCaller: Async process join finished "
f"after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
class PersistentAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: mp.Process = None
self.start_time: Optional[float] = None
ctx = mp.get_context('spawn')
# main queue to deliver `AsyncRequest` from host to the ckpt worker
self.queue: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to synchronize for the completion of preloading tensors to host
# between a trainer and ckpt worker
self.preload_q: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to inform trainer when the saving is completed
self.comp_q: mp.Queue = ctx.Queue()
self.cur_item: int = None
self.cur_idx: int = -1
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Put `AsyncRequest` to the Persistent Async Caller
This method must be called on all ranks.
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
async_req (AsyncRequest): `AsyncRequest` object containing to
schedule a checkpointing request
"""
if async_req.async_fn is None:
return # nothing to do
start_sync = end_sync = None
self.start_time = time()
if self.process is None:
ctx = mp.get_context('spawn')
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller"
)
self.process: mp.Process = ctx.Process(
target=PersistentAsyncCaller.async_loop,
args=(
torch.distributed.get_rank(),
self.queue,
self.preload_q,
self.comp_q,
logger.getEffectiveLevel(),
),
)
self.process.start()
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller"
)
if async_req.preload_fn:
self.preload_q.put(async_req.call_idx)
self.queue.put(async_req)
logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx}")
if async_req.preload_fn:
start_sync = time()
# Synchronize for pre-staging tensors
self.preload_q.join()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, "
f"takes {end_sync - start_sync} to finish D2H "
)
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}"
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} "
"to schedule async ckpt "
)
torch.distributed.all_reduce(ten)
if ten[0] > 0 and not blocking:
return False
else:
if self.process is not None:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
return True
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, training ranks simply check its
asynchronous checkpoint writer without synchronization.
Returns:
bool: True if all ranks are done (immediately of after active wait
if `blocking` is True), False if at least one rank is still active.
"""
is_alive: bool = False
if self.process:
while self.cur_item is None:
try:
# Retrieve comp call_idx without waiting
self.cur_item = self.comp_q.get_nowait()
except Empty:
# This method is called after any `AsyncRequest` is pushed to the main loop
# So, the background writing is still active
# before the worker put call_idx to `comp_q`
if not blocking:
is_alive = True
break
sleep(0.1)
if self.cur_item is not None:
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed, {is_alive}"
)
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
# This is set to False when blocking == False so this routine is called again
# to simply call `sync_all_async_calls` to check if other ranks complete the writing
if is_done:
# The current request is completed globally. Reset the current item for polling.
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed globally, {is_done}"
)
self.cur_item = None
return is_done
def close(self):
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
)
if self.process:
self.queue.put('DONE')
self.queue.join()
self.process.join()
self.process = None
@staticmethod
@_disable_gc()
def async_loop(
rank: int,
queue: mp.JoinableQueue,
preload_q: mp.JoinableQueue,
comp_q: mp.Queue,
log_level: int = logging.INFO,
):
"""Main function for the persistent checkpoint worker
The persisent worker is created once and terminated at exit or
when application calls `close()` explictily
This routine receives `AsyncRequest` and does `preload_fn` first and
put the integer value in `preload_q` to inform the trainer to proceed.
When the `async_fn` from the request` is completed (background saving is done),
it puts a integer value to `comp_q` to notify the trainer the completion.
Args:
rank (int): the rank of the trainer where the persistent worker is created.
queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest
from the training rank
preload_q (mp.JoinableQueue): a queue to inform trainer that preloading of tensors
from GPU to Host or dedicated location is completed
comp_q (mp.Queue): a queue to inform the training rank the completion of scheduled
async checkpoint request
log_level (int, Optional): an integer to set log-level in this spawned process
to get aligned with the training rank's logging level
"""
logger = logging.getLogger(__name__)
logger.setLevel(log_level)
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started")
while True:
item = queue.get()
if isinstance(item, str) and item == 'DONE':
queue.task_done()
break
elif isinstance(item, AsyncRequest):
async_fn_args = list(item.async_fn_args)
if item.preload_fn:
call_idx = preload_q.get()
# the 2nd arg is state dict
async_fn_args[1] = item.preload_fn()
logger.debug(f"{rank} has completed D2H of {call_idx}")
preload_q.task_done()
item.async_fn(*async_fn_args, **item.async_fn_kwargs)
logger.debug(f"{rank} has completed saving {item.call_idx}")
comp_q.put(item.call_idx)
queue.task_done()
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated")
class _ActiveAsyncRequest(NamedTuple):
......@@ -152,7 +451,7 @@ class _ActiveAsyncRequest(NamedTuple):
"""
idx: int
async_caller: DistributedAsyncCaller
async_caller: AsyncCaller
async_request: AsyncRequest
......@@ -163,9 +462,18 @@ class AsyncCallsQueue:
active calls with `maybe_finalize_async_calls`.
"""
def __init__(self):
def __init__(self, persistent: bool = False):
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
self.call_idx: int = -1
self.persistent: bool = persistent
self.persistent_caller: AsyncCaller = None
def _get_async_caller(self):
if not self.persistent:
return TemporalAsyncCaller()
if self.persistent_caller is None:
self.persistent_caller = PersistentAsyncCaller()
return self.persistent_caller
def schedule_async_request(self, async_request: AsyncRequest) -> int:
"""Start a new async call and add it to a queue of active async calls.
......@@ -180,13 +488,20 @@ class AsyncCallsQueue:
This can help the user keep track of the async calls.
"""
self.call_idx += 1
async_caller = DistributedAsyncCaller()
async_caller = self._get_async_caller()
# Backward compatibility for local checkpointing built with the old AsyncRequest
if len(async_request._fields) != len(AsyncRequest._fields):
async_request = AsyncRequest(**async_request._asdict())
async_request = async_request._replace(call_idx=self.call_idx)
finalize_fns = async_request.finalize_fns
async_request = async_request._replace(finalize_fns=None)
async_request = async_request.freeze()
async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
async_caller.schedule_async_call(async_request)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns))
return self.call_idx
def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
"""Finalizes all available calls.
This method must be called on all ranks.
......@@ -201,18 +516,20 @@ class AsyncCallsQueue:
"""
call_idx_finalized = []
while self.async_calls:
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(
blocking, no_dist
)
if not next_async_done:
break
call_idx, _, async_request = self.async_calls.popleft()
for finalize_fn in async_request.finalize_fns:
finalize_fn()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert (
ten.item() == call_idx
), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
call_idx_finalized.append(call_idx)
with debug_time("finalize", logger):
call_idx, _, finalize_fns = self.async_calls.popleft()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert ten.item() == call_idx, 'Unmatched async calls. '
'That probably means not all ranks are participating in async finalization'
for finalize_fn in finalize_fns:
finalize_fn()
call_idx_finalized.append(call_idx)
return call_idx_finalized
def get_num_unfinalized_calls(self):
......@@ -222,3 +539,5 @@ class AsyncCallsQueue:
def close(self):
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
if self.persistent and self.persistent_caller:
self.persistent_caller.close()
......@@ -28,9 +28,10 @@ async_calls = AsyncCallsQueue()
def get_default_strategy(action: StrategyAction, backend: str, version: int):
"""Retrieves a default strategy for a given action, backend and version."""
error_hint: str = None
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages'
from .tensorstore import register_default_tensorstore_strategies
register_default_tensorstore_strategies()
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" FS Reader with metadata cached support. """
import os
from typing import Union
from torch.distributed.checkpoint import FileSystemReader, Metadata
class CachedMetadataFileSystemReader(FileSystemReader):
"""
Extends FileSystemReader to cache metadata for improved performance.
Attributes:
_cached_metadata (Metadata or None): Cached metadata from the file system.
"""
def __init__(self, path: Union[str, os.PathLike]) -> None:
"""
Initialize with file system path.
Args:
path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
"""
super().__init__(path=path)
self._cached_metadata = None
def read_metadata(self) -> Metadata:
"""
Read metadata from file system, caching for subsequent calls.
Returns:
Metadata: Checkpoint metadata.
"""
if self._cached_metadata is None:
self._cached_metadata = super().read_metadata()
return self._cached_metadata
......@@ -69,7 +69,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
......@@ -95,12 +95,12 @@ class TorchCommonLoadStrategy(LoadCommonStrategy):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path)
loaded_obj = torch.load(load_path, weights_only=False)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path)
loaded_obj = torch.load(old_load_path, weights_only=False)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import gc
import dataclasses
import logging
import os
import queue
from contextlib import contextmanager
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
......@@ -20,6 +22,8 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from .async_utils import _disable_gc
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
......@@ -35,19 +39,6 @@ def _get_write_results_queue():
return _results_queue
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
......@@ -76,6 +67,8 @@ class FileSystemWriterAsync(FileSystemWriter):
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
self.can_run_decentralized_global_plan: bool = True
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
......@@ -99,7 +92,7 @@ class FileSystemWriterAsync(FileSystemWriter):
self.thread_count > 1
), "thread_count must be at least 2 if separation_hint is provided"
bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
item_buckets = _split_by_size_and_type(bins, plan.items, self.separation_hint)
item_buckets = _split_by_size_and_type(bins, plan.items)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
......@@ -113,6 +106,23 @@ class FileSystemWriterAsync(FileSystemWriter):
file_count += 1
return file_name
def _clone_if_needed(ten: torch.Tensor):
"""Clone if we detect incontiguous storage for CPU tensors
Makes sure we perform a `clone` only if we detect incontiguous storage,
so that we don't blow up host memory unnecessarily.
TODO: For persistent worker, this work should be changed to move the cpu tensor
to shared_memory.
"""
ten = ten.detach()
if ten.device.type != "cpu":
# We do D2H later when the async_request is scheduled for both sync / async
# checkpointing
return ten
is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
return ten.clone() if is_view else ten
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for group_name, group_buckets in _split_by_separation_hint(
......@@ -125,7 +135,7 @@ class FileSystemWriterAsync(FileSystemWriter):
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
(item, _clone_if_needed(planner.resolve_data(item)))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
......@@ -147,23 +157,49 @@ class FileSystemWriterAsync(FileSystemWriter):
end = time()
logger.debug(f"D2H and push, time: {end - start}")
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: None (if there is nothing to write on this rank) or a tuple of:
- the function that saves the data
- arguments to that function
1) the function that saves the data.
2) the function that stages the GPU tensors to a destination for async checkpointing.
This function should be self-contained.
3) arguments to that function in 1).
"""
if not self.write_buckets:
return None, ()
return (self.write_preloaded_data_multiproc, (self.write_buckets, self.results_queue))
return None, None, ()
return (
self.write_preloaded_data_multiproc,
partial(self.preload_tensors, self.write_buckets, True),
[torch.distributed.get_rank(), self.write_buckets, self.results_queue],
)
@staticmethod
def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
"""Preload tensors in state_dict to host memory through CPU memory
Args:
write_buckets(List): List of `WriteBucket`,
which includes what to be saved in a checkpoint
non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
"""
result = []
for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
if non_blocking:
torch.cuda.synchronize()
return result
@staticmethod
@_disable_gc()
def write_preloaded_data_multiproc(
write_buckets: List[WriteBucket], global_results_queue: mp.Queue
rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue
) -> None:
"""
Performs saving data to storage with multiple processes.
......@@ -186,6 +222,7 @@ class FileSystemWriterAsync(FileSystemWriter):
(or an Exception) from parallel write processes to the main training process
Returns: None
"""
logger = logging.getLogger(__name__)
w_start = time()
write_results_or_exc: Union[dict, Exception] = dict()
ctx = mp.get_context('fork')
......@@ -234,20 +271,16 @@ class FileSystemWriterAsync(FileSystemWriter):
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
else:
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
logger.debug('FileSystemWriterAsync: collected worker results successfully')
global_results_queue.put(write_results_or_exc)
w_end = time()
logger.debug(
f"{w_end}, rank: {torch.distributed.get_rank()},"
f" write(sync,parallel): {w_end - w_start}"
)
logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}")
@staticmethod
@_disable_gc()
......@@ -271,6 +304,8 @@ class FileSystemWriterAsync(FileSystemWriter):
Returns: None, the write result are put into the `queue`
"""
logger = logging.getLogger(__name__)
logger.debug(f'{local_proc_idx} started')
mem_before = _process_memory()
local_results = []
......@@ -288,6 +323,7 @@ class FileSystemWriterAsync(FileSystemWriter):
os.fsync(stream.fileno())
local_output = (local_proc_idx, local_results)
except Exception as e:
logger.debug(f'{local_proc_idx} failed')
local_output = (local_proc_idx, e)
results_queue.put(local_output)
......@@ -334,10 +370,23 @@ class FileSystemWriterAsync(FileSystemWriter):
)
return list(chain.from_iterable(write_results.values()))
def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Instead of assigning indices by plan order, uses PyT rank (same outcome).
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
def _split_by_size_and_type(
bins: int, items: List[WriteItem], separation_hint: Optional[str] = None
) -> List[List[WriteItem]]:
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
return dataclasses.replace(
local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
)
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Splits write items according to item size into close to uniform bins.
......@@ -353,24 +402,32 @@ def _split_by_size_and_type(
if bins == 1:
return [items]
bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
bytes_items: List[WriteItem] = []
tensor_items: List[WriteItem] = []
for wi in items:
container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
container.append(wi)
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_items.sort(key=_item_size, reverse=True)
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Then, assign tensors according to their sizes
for item in tensor_items:
# TODO replace with headq
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
# Sort tensor items by size in decreasing order once and store the size with item
sized_tensors = [(item, _item_size(item)) for item in tensor_items]
sized_tensors.sort(key=itemgetter(1), reverse=True)
# Use a min heap for bin assignment
# Store (total_size_of_bin, bin_index) tuples
heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]
# Assign tensors using heap
for item, size in sized_tensors:
total_bin_size, bin_idx = heappop(heap)
buckets[bin_idx].append(item)
heappush(heap, (total_bin_size + size, bin_idx))
return buckets
......
......@@ -2,12 +2,13 @@
import logging
from pathlib import Path
from time import time
from typing import Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import Metadata
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
......@@ -19,6 +20,7 @@ from megatron.core.dist_checkpointing.exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
exchange_loaded_objects_gather_object,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.strategies.base import (
......@@ -26,7 +28,12 @@ from megatron.core.dist_checkpointing.strategies.base import (
LoadShardedStrategy,
SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.utils import _sharded_tensor_shard_id, _ShardId
from megatron.core.dist_checkpointing.utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
_ShardId,
debug_time,
)
from megatron.core.dist_checkpointing.validation import (
determine_global_metadata,
validate_sharding_integrity,
......@@ -34,6 +41,8 @@ from megatron.core.dist_checkpointing.validation import (
logger = logging.getLogger(__name__)
T = TypeVar('T', ShardedObject, ShardedTensor)
class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
"""Wraps arbitrary strategy and distributes the save during `save`.
......@@ -170,7 +179,9 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
self.exchange_algo = exchange_algo
self.cached_distribution: Optional[ShardDistribution] = None
self.cached_global_metadata: Optional[Metadata] = None
@debug_time("FullyParallelLoadStrategyWrapper.load", logger)
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Distributes the load and calls underlying strategy only for parts of the state dict.
......@@ -200,18 +211,20 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
a state dict that would be loaded with the underlying strategy
without this wrapper.
"""
loaded_state_dict = {}
if torch.distributed.get_world_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
# Step 1 and 2: exchange load metadata and distribute the load
start = time()
precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
end = time()
logger.debug(f'self.apply_loading_parallelization took {end - start}s')
start = end
with debug_time("self.apply_loading_parallelization", logger):
precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
sharded_state_dict
)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
# Step 3: load part of the checkpoint.
# Load only sharded objects first. ShardedTensors will be loaded separately
......@@ -219,88 +232,121 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
(sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
self._defer_loading_sharded_tensors(sharded_state_dict)
)
loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedObjects took {end - start}s')
start = end
(sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = (
self._defer_loading_sharded_objects(sharded_state_dict)
)
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
assert (
len(sharded_state_dict) == 0
), "sharded_state_dict is not empty after deferring tensors and objects"
with debug_time("base_load_ShardedObjects", logger):
# Load sharded objects first
loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)
with debug_time("base_load_ShardedTensors", logger):
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
with debug_time("self.exchange_loaded_tensors", logger):
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
)
end = time()
logger.debug(f'Base load of ShardedTensors took {end - start}s')
start = end
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
with debug_time("torch.cuda.synchronize", logger):
torch.cuda.synchronize()
all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects)
if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()):
missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
f'Missing object shards after fully parallel loading: {missing_object_shards}'
)
sync_start = time()
torch.cuda.synchronize()
end = time()
logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
logger.debug(f'self.exchange_loaded_tensors took {end - start}s')
self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects)
merge(loaded_state_dict, sharded_objects)
merge(loaded_state_dict, sharded_tensors)
if hasattr(self.base_strategy, "cached_global_metadata"):
self.cached_global_metadata = self.base_strategy.cached_global_metadata
return loaded_state_dict
@staticmethod
def _defer_loading_sharded_objects(
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedObject],
Dict[_ShardId, ShardedObject],
]:
return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id)
@staticmethod
def _defer_loading_sharded_tensors(
self, sharded_state_dict: ShardedStateDict
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedTensor],
Dict[_ShardId, ShardedTensor],
]:
"""Divides state dict into parts loaded by this vs other ranks.
return _defer_loading_sharded_items(
sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id
)
ShardedTensors with main replica_id will be loaded by this rank,
others will be received by other ranks (after loading from storage).
@staticmethod
def fill_in_deferred_sharded_objects(
sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any]
) -> None:
"""Fill in objects not loaded by current rank with objects from `loaded_objects` map.
Args:
sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
that will be divided.
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with ShardedTensors
- ShardedStateDict: sub-state dict with non-ShardedTensors
- Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified
by shard ids. This is a mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *this* rank
- Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedObjects are completely replaced with corresponding objects.
loaded_objects (Dict[_ShardId, Any]): dict allowing to map
ShardedObject from the sharded_state_dict to loaded objects.
sharded_tensors, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id
)
def wrap_non_main_replicas(x):
if isinstance(x, ShardedTensor):
# Assign shard to be loaded or not
if is_main_replica(x.replica_id):
to_load_shards[_sharded_tensor_shard_id(x)] = x
else:
unloaded_shards[_sharded_tensor_shard_id(x)] = x
return x
@staticmethod
def fill_in_deferred_sharded_tensors(
sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id
)
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
......@@ -339,34 +385,6 @@ class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
return precomputed_distribution
def fill_in_deferred_sharded_tensors(
self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
Returns:
"""
def fill_in_sharded_tensor(x):
if isinstance(x, ShardedTensor):
try:
x = loaded_tensors[_sharded_tensor_shard_id(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
......@@ -437,3 +455,61 @@ def distribute_main_replicas_with_precomputed_distribution(
sh_ten.replica_id = 0
else:
sh_ten.replica_id = 1
def _defer_loading_sharded_items(
sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId]
) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]:
"""Divides state dict into parts loaded by this vs other ranks.
Args:
sharded_state_dict (ShardedStateDict): state dict with sharded items
that will be divided.
item_type: The type of sharded item (ShardedObject or ShardedTensor)
shard_id_func: Function to get the shard ID for the item type
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with sharded items
- ShardedStateDict: sub-state dict with non-sharded items
- Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank
- Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_items, remaining_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, item_type)
)
def wrap_non_main_replicas(x: Any) -> Any:
if isinstance(x, item_type):
shard_id = shard_id_func(x)
if is_main_replica(x.replica_id):
to_load_shards[shard_id] = x
else:
unloaded_shards[shard_id] = x
return x
dict_list_map_inplace(wrap_non_main_replicas, sharded_items)
return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards
def _fill_in_deferred_sharded_items(
sharded_state_dict: ShardedStateDict,
loaded_items: Dict[_ShardId, Any],
item_type: type,
shard_id_func: Callable[[T], _ShardId],
) -> None:
"""Helper function to fill in items not loaded by current rank."""
def fill_in_sharded_item(x: Any) -> Any:
if isinstance(x, item_type):
try:
x = loaded_items[shard_id_func(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded item shard: {shard_id_func(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict)
......@@ -13,7 +13,7 @@ import logging
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
......@@ -27,7 +27,6 @@ from megatron.core.dist_checkpointing.dict_utils import (
extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
StateDict,
......@@ -84,11 +83,7 @@ def is_nd_flattened_tensor(sh_ten: Any) -> bool:
Returns:
bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
"""
return (
isinstance(sh_ten, ShardedTensor)
and sh_ten.flattened_range is not None
and len(sh_ten.global_shape) > 1
)
return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None
# information needed to restore. With current implementation, this is a nested state dict
......@@ -132,8 +127,12 @@ def apply_nd_flattened_tensors_reformulation(
try:
sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
except KeyError as e:
# Handle legacy checkpointing where 1-D flatten tensor metadata was not saved
if len(sh_ten.global_shape) == 1:
return sh_ten
raise CheckpointingException(
f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}'
f'Missing reformulation metadata for tensor {sh_ten}. '
f'Existing keys: {reformulation_metadata.keys()}'
) from e
ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
......@@ -235,13 +234,16 @@ def reformulate_single_nd_flattened_tensor(
):
# without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
# `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units
# `math.ceil` argument is an exact offset of the app next shard expressed
# in ckpt_local_shape units
next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))
logger.debug(
f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}'
f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}'
f'Generated the following number of overlap shards for each dimension: '
f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt '
f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} '
f'and chunk offset {sh_ten.local_chunk_offset_in_global()}'
)
reformulated_sh_tens = {}
for chunk_offset in product(*overlap_dim_offsets):
......@@ -286,7 +288,8 @@ def reformulate_single_nd_flattened_tensor(
# For each ckpt shard, we fill the appropriate application shard part
dest_ten = app_non_flat_ten
src_ten = ckpt_ten.view(ckpt_local_shape)
# We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice
# We don't need narrowing over `prepend_axis_num` axes so we take
# the [sh_ten.prepend_axis_num:] offsets slice
for (
dim,
offset_for_saved_tensor,
......
......@@ -4,7 +4,7 @@
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -16,19 +16,37 @@ from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
from .torch import MCoreSavePlanner
logger = getLogger(__name__)
from dataclasses import fields
def _compare_dataclasses(obj1, obj2):
if type(obj1) != type(obj2):
return f"Objects are of different types: {type(obj1)} and {type(obj2)}"
differences = []
for field in fields(obj1):
value1 = getattr(obj1, field.name)
value2 = getattr(obj2, field.name)
if value1 != value2:
differences.append(f"{field.name}: {value1} != {value2}")
return differences if differences else "All fields are equal"
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]:
loaded_all_plans: Optional[List[SavePlan]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]:
"""
First stage of saving a state dict to storage.
......@@ -62,7 +80,7 @@ def save_state_dict_async_plan(
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- metadata from planning (or None if we reuse cached global metadata)
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning.
......@@ -80,6 +98,7 @@ def save_state_dict_async_plan(
global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan
global_md_verify_reuse = False
def local_step():
nonlocal local_plan
......@@ -101,11 +120,34 @@ def save_state_dict_async_plan(
return all_local_plans
# Execute local and global planning
# Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
# allow it (`can_run_decentralized_global_plan`) we gather the plans to create
# the metadata but prepare the plans independently on each rank.
# In the worst case we have to reduce_scatter all the plans.
start_plan = time()
if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
storage_writer, 'can_run_decentralized_global_plan', False
):
local_plan = local_step()
global_md_verify_reuse = verify_global_md_reuse(
loaded_all_plans, local_plan, rank, dist_wrapper
)
if not loaded_all_plans or not global_md_verify_reuse:
all_local_plans = dist_wrapper.gather_object(local_plan)
if dist_wrapper.is_coordinator:
_, global_metadata = planner.create_global_plan(all_local_plans)
global_metadata.all_local_plans = all_local_plans
else:
logger.debug(f"rank: {rank}, Passed cached global metadata")
global_metadata = None
local_plan = planner.create_decentralized_global_plan(local_plan)
local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
central_plan = local_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
......@@ -118,13 +160,56 @@ def save_state_dict_async_plan(
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return (
(storage_writer, cast(Metadata, global_metadata), dist_wrapper),
(storage_writer, global_metadata, dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
global_md_verify_reuse,
)
def verify_global_md_reuse(
loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
) -> bool:
"""
Verifies that global metadata reuse is possible by checking the loaded plans from the
checkpoint are consistent, which means we have the same settings when resuming training.
Args:
loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
(stored in checkpoint metadata).
local_plan: SavePlan, The local save plan.
rank: Current process rank.
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: True iff the global metadata reuse is possible.
"""
logger.debug(f"verifying reuse of global metadata")
if not loaded_all_plans:
global_md_verify_reuse = False
logger.debug("loaded global metadata reuse verification: no loaded plans passed")
elif len(loaded_all_plans) == dist_wrapper.get_world_size():
local_verify_reuse = all(
getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
for f in fields(local_plan)
if f.name != 'storage_data'
)
if not local_verify_reuse:
logger.debug(
f"local_verify_reuse is False: diffs -"
f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
)
all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
# Check if all reduced results are True
global_md_verify_reuse = all_results.item() == 1
else:
global_md_verify_reuse = False
return global_md_verify_reuse
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None:
......
File mode changed from 100755 to 100644
......@@ -55,6 +55,7 @@ from .base import (
StrategyAction,
register_default_strategy,
)
from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
TensorReformulationMetadata,
......@@ -126,7 +127,9 @@ def flatten_state_dict(
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
sh_tens: List[ShardedTensor],
rank: Optional[int] = None,
load_legacy_1d_flatten_tensors: bool = False,
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
......@@ -138,13 +141,12 @@ def sharded_tensor_to_torch_sharded_tensor(
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
This function handles 3 different type of ShardedTensors:
This function handles 2 different type of ShardedTensors:
1. Non-flat regular ShardedTensors (`not has_flattened_range`)
2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
3. N-D flattened ShardedTensors (`has_flattened_range`)
2. N-D flattened ShardedTensors (`has_flattened_range`)
(1) and (2) type are saved according to their original shape.
Type (3) however requires global shape adjustment for efficiency:
(1) type are saved according to their original shape.
Type (2) however requires global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
......@@ -154,6 +156,8 @@ def sharded_tensor_to_torch_sharded_tensor(
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
should be loaded in a legacy way. Defaults to False.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
......@@ -163,41 +167,21 @@ def sharded_tensor_to_torch_sharded_tensor(
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1
for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()
if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
# Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
has_flattened_range = False
local_global_offsets = {}
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if is_flattened_range_1d:
# Type (2) case: 1D flattened ShardedTensors
for sh_ten in sh_tens:
assert len(sh_ten.global_offset) == 1, sh_ten
assert sh_ten.prepend_axis_num == 0, sh_ten
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
global_shape = some_sh_ten.global_shape
offsets_shape = (
some_sh_ten.local_shape
) # local shape is not flattened, we need it for chunk offsets
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
[
sh_ten.global_offset[0] + sh_ten.flattened_range.start
], # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
elif has_flattened_range:
if has_flattened_range:
# Type (3) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
......@@ -250,10 +234,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if is_flattened_range_1d:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
elif has_flattened_range:
if has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
......@@ -270,7 +251,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
# The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
placement = f"rank:{(rank + 1) % world_size}/cuda"
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
......@@ -296,7 +277,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# This won't be stored in the checkpoint, only for runtime purposes
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten
......@@ -305,6 +286,7 @@ def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
load_legacy_1d_flatten_tensors: bool = False,
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Convert state dict with ShardedTensors and ShardedObjects
to state dict compatible with PyT Dist format.
......@@ -348,7 +330,9 @@ def mcore_to_pyt_state_dict(
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
sh_tens, rank, load_legacy_1d_flatten_tensors
)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
......@@ -460,6 +444,7 @@ class MCoreSavePlanner(DefaultSavePlanner):
*args,
dedup_replicated_tensors: Optional[bool] = None,
nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
can_run_decentralized_global_plan: bool = True,
**kwargs,
) -> None:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
......@@ -468,6 +453,14 @@ class MCoreSavePlanner(DefaultSavePlanner):
kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
self.can_run_decentralized_global_plan = can_run_decentralized_global_plan
if can_run_decentralized_global_plan:
assert (
not dedup_replicated_tensors
), 'Cannot run decentralized plan with dedup_replicated_tensors=True'
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
def create_local_plan(self) -> SavePlan:
"""Adds IOBytes write request on non-coordinator ranks."""
......@@ -503,6 +496,23 @@ class MCoreSavePlanner(DefaultSavePlanner):
metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
return global_plan, metadata
def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Nothing to do, just some checks.
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan'
return local_plan
def transform_object(self, write_item: WriteItem, object: Any):
"""Make no transformations - bytes objects are already serialized."""
return object
......@@ -535,6 +545,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner):
else:
expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if loaded_shape != expected_shape:
if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape = sh_ten.global_shape
if loaded_shape == expected_shape:
continue
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
......@@ -634,6 +650,8 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.separation_hint = separation_hint
self.validated_loaded_metadata_reuse = False
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
......@@ -663,7 +681,14 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure:
loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
if loaded_all_plans is None:
logger.debug(
"no all_local_plans in metadata - can't verify global metadata reuse..."
)
args_cached_plans = (
self.cached_central_plan,
self.cached_local_plan,
......@@ -675,24 +700,44 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
self.validated_loaded_metadata_reuse,
) = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
coordinator,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
planner=MCoreSavePlanner(
dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False
),
cached_ckpt_structure=args_cached_plans,
loaded_all_plans=loaded_all_plans,
)
rank = torch.distributed.get_rank()
if self.use_cached_ckpt_structure:
if self.validated_cache_reuse:
if (
loaded_all_plans
and self.cached_global_metadata
and self.validated_loaded_metadata_reuse
):
if coordinator == rank:
logger.debug(
f"rank: {rank}, reuse global metadata from loaded"
f" .metadata, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
elif self.validated_cache_reuse:
logger.debug(f"rank: {rank}, cache validated")
if save_state_dict_ret[1]: # when global_metadata is not cached
self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata
# Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata)
elif coordinator == rank:
logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}")
logger.debug(
f"rank: {rank}, reuse global metadata cached from previous"
f" save iteration, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
......@@ -700,13 +745,13 @@ class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args()
save_fn, save_args = save_fn_args
save_fn, preload_fn, save_args = save_fn_args
def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn])
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
def can_handle_sharded_objects(self):
return True
......@@ -736,6 +781,12 @@ def get_reformulation_metadata(
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skip metadata reformulation.'
)
continue
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
......@@ -750,6 +801,10 @@ def get_reformulation_metadata(
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format."""
def __init__(self):
self.cached_global_metadata: Optional[Metadata] = None
super().__init__()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
......@@ -761,10 +816,18 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, reformulation_metadata
)
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
......@@ -776,15 +839,23 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy):
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
fsr = CachedMetadataFileSystemReader(checkpoint_dir)
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
fsr,
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
self.cached_global_metadata = (
fsr.read_metadata()
) # no storage interaction thanks to caching
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from logging import getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .tensorstore import _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None
......@@ -26,9 +25,16 @@ _import_trigger = None
timers = defaultdict(list)
logger = getLogger(__name__)
logger.warning(
'megatron.core.dist_checkpointing.two_stage module is deprecated'
' and will be removed in Megatron-Core v0.12. Please use'
' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
)
def timed(verbose=True):
"""Timing decorator."""
def timed_dec(fn):
name = fn.__name__
......@@ -59,6 +65,7 @@ class _ShardedTensorMetadata:
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
"""Id of a sharded tensor."""
return (sharded_tensor.key, sharded_tensor.global_offset)
......@@ -101,6 +108,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Main load method."""
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
......@@ -109,6 +117,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
return sharded_state_dict
def summarize_load_times(self):
"""Summarize load times."""
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
......@@ -124,6 +133,7 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
"""Load tensor from storage."""
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array(
ten_meta.sharded_tensor_no_data,
......@@ -136,12 +146,15 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
@timed()
def maybe_init_gloo_group(self):
"""Create Gloo groups."""
if not self.cpu_transfer:
return
all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups):
# "two_stage" module will be deprecated, so not replace new_group()
# with ...parallel_state.create_group() func setting group_desc here.
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg
......@@ -211,7 +224,8 @@ class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
)
logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}\
({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
)
torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank
......
File mode changed from 100755 to 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
import torch
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values
from .exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
)
from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges
from .state_dict_utils import load_preprocess, save_preprocess
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
debug_time,
extract_sharded_base,
zip_strict,
)
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
@dataclass
class MCoreTensorAwareStateDict(TensorAwareStateDict):
"""
MCore-specific class defining the interface between the MCore state dict and checkpoint manager.
This class distinguishes between raw objects, the common state dict, and sharded state dicts
(tensor parts). It also handles optional metadata needed for fully parallel save/load.
"""
common: StateDict
sharded_state_dict: ShardedStateDict
_is_hollow: bool = False
@staticmethod
def _validate_params(algo):
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
@staticmethod
def _get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_distribution=None
):
if fully_parallel:
if cached_distribution is None:
distribution = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
logger.debug(f'MCore_TASD._get_distribution calculated distribution')
else:
distribution = cached_distribution
logger.debug(f'MCore_TASD._get_distribution used cache')
else:
distribution = (None, None, None, None)
logger.debug(f'MCore_TASD._get_distribution returned empty distribution')
return distribution
@staticmethod
def _remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
):
if fully_parallel:
for sh_base in nested_values(sharded_part):
# TODO remove redundant objects as well
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_to_saving_rank[shard_id] != torch.distributed.get_rank(
group=parallelization_group
):
sh_base.data = None
@classmethod
@debug_time("from_state_dict", logger)
def from_state_dict(
cls,
sharded_state_dict: ShardedStateDict,
algo: str = 'fully_parallel',
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
cached_metadata: ShardDistribution = None,
) -> Tuple[TensorAwareStateDict, ShardDistribution]:
"""
Constructs a TensorAwareStateDict from a sharded state dictionary.
This method preprocesses the input `sharded_state_dict`, validates parameters,
and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`.
Args:
sharded_state_dict: The input sharded state dictionary to be converted.
algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'.
- 'fully_parallel' enables fully parallel initialization.
parallelization_group (Optional): A distributed process group for parallelization.
cached_metadata (Optional): Precomputed metadata from previous saves.
- Reuses data that doesn't need recalculation, optimizing the creation process.
Returns:
TensorAwareStateDict: An instance initialized with the provided sharded state dictionary
and optional cached metadata.
- The metadata is stored in memory to speed up future saves.
"""
with debug_time("_get_distribution", logger):
cls._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(
sharded_state_dict, cached_metadata is None
)
cacheable_distribution = cls._get_distribution(
fully_parallel, sharded_part, parallelization_group, cached_metadata
)
if cacheable_distribution is not None:
shard_to_saving_rank, _, _, _ = cacheable_distribution
cls._remove_redundant_data(
fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
)
return (
MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part),
cacheable_distribution,
)
@property
def is_hollow(self):
"""
True iff tensors had been extracted and have not been inserted back yet.
"""
return self._is_hollow
@property
def _sharded_tensors(self):
# Three possible states for sharded_tensor:
# 1. sharded_tensor with data (.data = tensor)
# 2. sharded_tensor hollow (.data = None, .orig_device = orig_device)
# 3. removed sharded_tensor (.data = None, no device information)
# TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data
if self.is_hollow:
for sh_base in nested_values(self.sharded_state_dict):
# FIXME: Hacky way to store the original device of the popped tensor
if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'):
yield sh_base
else:
for sh_base in nested_values(self.sharded_state_dict):
if isinstance(sh_base, ShardedTensor) and sh_base.data is not None:
yield sh_base
@property
def tensors(self) -> Iterator[torch.Tensor]:
"""
Get the tensor data from the state dict.
"""
assert not self.is_hollow # TODO raise exception
return map(lambda sh_ten: sh_ten.data, self._sharded_tensors)
@property
def common_state_dict(self) -> Dict:
"""
Get the common state dict from the state dict.
"""
return self.common
def pop_tensors(self) -> List[torch.Tensor]:
"""
Extracts the tensor data from the wrapped state dict, preserving metadata.
Replaces the tensor data in sharded_tensors with device type of extracted tensors.
After this operation, the state dictionary is "hollow", containing no tensor data.
Further calls to `pop_tensor` will raise an error.
@return List of extracted tensors
"""
assert not self.is_hollow # TODO raise exception
result = []
for sh_ten in self._sharded_tensors:
result.append(sh_ten.data)
# FIXME: Hacky way to store the original device, which is not included in the metadata
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = None
self._is_hollow = True
return result
def insert_tensors(self, tensor_data: Iterable[torch.Tensor]):
"""
Reverse of `pop_tensors`. Replaces device type in sharded_tensors with actual values
Value of `self` is considered to be the same after:
```
self.insert_tensors(self.pop_tensors())
```
"""
assert self.is_hollow # TODO raise exception
for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data):
# FIXME: Hacky way to store the original device
if sh_ten.orig_device == ten.device.type:
delattr(sh_ten, 'orig_device')
# Tensor might be on non-original device
sh_ten.data = ten
self._is_hollow = False
def init_tensors(self):
"""
Initializes empty tensors with the same properties as the original tensors.
This function should only be called after the original tensors have been popped.
It ensures that the newly created empty tensors match the shape,
dtype, and device of the originals, but contain no data.
"""
assert self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# Hacky way to retrieve the original device
sh_ten.init_data(sh_ten.orig_device)
delattr(sh_ten, 'orig_device')
self._is_hollow = False
def copy_tensors_to_cpu(self, non_blocking=False):
"""
Stores CPU copies of tensors in the state_dict, replacing the originals,
but without destroying them.
The original devices are remembered for restoration with restore_tensor_device().
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
if sh_ten.data.device.type == 'cpu':
# Skip cloning if it's already confirmed to be a copy
if not hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.clone()
else:
# FIXME: Hacky way to store the original device
if not hasattr(sh_ten, 'orig_device'):
setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking)
def restore_tensor_device(self, non_blocking=True):
"""
Restores all tensors to their original devices, if a move is required.
Using non_blocking=True allows for asynchronous copying.
"""
assert not self.is_hollow # TODO raise exception
for sh_ten in self._sharded_tensors:
# FIXME: Hacky way to store the original device
if hasattr(sh_ten, 'orig_device'):
sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking)
delattr(sh_ten, 'orig_device')
def _insert_sharded_data(
self, fully_parallel, sharded_part, parallelization_group, exchange_algo
):
loaded_tensors = {}
for sh_ten in self._sharded_tensors:
loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data
if fully_parallel:
with debug_time("_get_distribution", logger):
distribution = self._get_distribution(
fully_parallel, sharded_part, parallelization_group
)
if distribution is not None:
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
# TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
with debug_time("exchange_by_distribution", logger):
loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
distribution,
parallelization_group,
exchange_algo,
)
torch.cuda.synchronize()
loaded_objects = {}
for sh_base in nested_values(self.sharded_state_dict):
if not isinstance(sh_base, ShardedTensor):
assert isinstance(sh_base, ShardedObject)
loaded_objects[_sharded_object_id(sh_base)] = sh_base.data
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys())
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
@debug_time("to_state_dict", logger)
def to_state_dict(
self,
sharded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Convert tensor-aware dict back to the original state_dict
"""
with debug_time("load_preprocess_and_state_dict_manipulations", logger):
assert not self.is_hollow # TODO raise exception
self._validate_params(algo)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate self.sharded_state_dict"] and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
with debug_time("validate_sharding_integrity", logger):
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
with debug_time("_insert_sharded_data", logger):
self._insert_sharded_data(
fully_parallel, sharded_part, parallelization_group, exchange_algo
)
with debug_time("apply_factory_merges", logger):
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
import logging
from contextlib import contextmanager
from time import time
from typing import Dict, Optional, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
......@@ -20,6 +22,18 @@ from .mapping import (
_ShardId = Tuple[str, tuple, Optional[tuple]]
def zip_strict(*args):
"""
Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).
Apart from providing functionality in earlier versions of Python is also more verbose.
(Python's zip does not print lengths, only which iterable has finished earlier)
"""
args = [list(a) for a in args]
lens = [len(a) for a in args]
assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
return zip(*args)
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
"""Unique id of the sharded tensor data.
......@@ -217,3 +231,89 @@ def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
fallback_logger = logging.getLogger(__name__)
__LOGGER_NAME_STACK = []
__LOGGER_STACK = []
@contextmanager
def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None):
"""Context manager for managing logger and name stack.
Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical
logging and contextual logger usage. Ensures the logger stack is restored afterward.
Args:
name (str, optional): Name to add to the logger stack. Defaults to None.
current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in
the stack or a fallback if none exist.
Yields:
Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and
the current logger for the block.
Example:
with logger_stack("scope", logger):
logger.info("Log within 'scope'")
"""
if name:
__LOGGER_NAME_STACK.append(name)
if current_logger:
__LOGGER_STACK.append(current_logger)
last_logger = current_logger
elif __LOGGER_STACK:
last_logger = __LOGGER_STACK[-1]
else:
last_logger = fallback_logger
try:
yield ".".join(__LOGGER_NAME_STACK), last_logger
finally:
if name and __LOGGER_NAME_STACK:
__LOGGER_NAME_STACK.pop(-1)
if current_logger and __LOGGER_STACK:
__LOGGER_STACK.pop(-1)
@contextmanager
def debug_time(
name: str, logger: Optional[logging.Logger] = None, threshold: float = float("-inf"), level=None
):
"""Simple context manager for timing functions/code blocks.
Args:
name (str): Label describing the code being measured.
logger (logging.Logger, optional): Logger for output. Defaults to the lowest logger.
threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster.
level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset;
WARNING otherwise.
"""
with logger_stack(name, logger) as (stacked_name, last_logger):
start = time()
try:
yield
finally:
result = time() - start
if result < threshold:
return
if level is None:
level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
last_logger.log(level, f"{stacked_name} took {result:.4f}s")
def debug_msg(msg: str):
"""Logs a debug message using the current logger stack.
This function formats and logs a debug message with the current logger
and name stack, preserving context from the logger_stack context manager.
Args:
msg (str): The message to be logged at the debug level.
Example:
debug_msg("Checkpoint initialized")
# Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
"""
with logger_stack(None, None) as (stacked_name, last_logger):
last_logger.debug(f"{stacked_name} {msg}")
......@@ -412,7 +412,7 @@ def validate_sharding_integrity(
CheckpointingException for invalid access pattern
"""
if common_state_dict:
if common_state_dict is not None:
_validate_common_state_dict(common_state_dict)
if torch.distributed.get_rank() != 0:
......@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
# For each shard with at least 1 flattened tensor in it, the above
# `_validate_sharding_for_key_flattened` ensure a correct consistent pattern
# The only thing that can go wrong at this point is that some shard don't have
# *any* representatives which will be checked later by comparing `shard_access_cnt == 1`
shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
if not torch.all(shard_access_cnt == 1):
raise CheckpointingException(
f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
)
def _compute_shards_access(rank_sharding):
......@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
expected_size = np.product(local_shape)
if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
)
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .loss_func import loss_func
from .model_provider import model_provider
from .fully_sharded_data_parallel import FullyShardedDataParallel
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from megatron.core import parallel_state
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.distributed.custom_fsdp.param_and_grad_buffer import (
AllGatherPipeline,
BucketingPolicy,
GradReducePipeline,
ParamAndGradBuffer,
PrefetchOrder,
)
from megatron.core.distributed.data_parallel_base import _BaseDataParallel
from megatron.core.distributed.distributed_data_parallel_config import DistributedDataParallelConfig
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import is_submodule, log_single_rank
logger = logging.getLogger(__name__)
class TrainingState(Enum):
"""States of a FSDP parameter group, which are coupled with
the sharding activity of parameters and gradients during training."""
# From pre-forward before post-forward, where parameters should be unsharded
FORWARD = auto()
# Prior to backward computation, where parameters should be unsharded
PRE_BACKWARD = auto()
# After backward computation, where gradients should be re-sharded
POST_BACKWARD = auto()
# Before and after module forward computaton or before pre-backward and
# after post-backward states, where no un/sharding activity happens
IDLE = auto()
class FullyShardedDataParallel(_BaseDataParallel):
"""Fully Sharded Data Parallel training for MCore models.
A distributed training wrapper that shards model parameters, gradients and optimizer
states across data parallel workers. Integrates seamlessly with MCore's tensor
and expert parallelism features.
We supports following modes:
- no_shard: Traditional data parallel training without parameter sharding.
- optim: Shards optimizer states, this is conceptually close to "ZeRO-1", and
main weights for mixed precision training, meanwhile the following `optim_grads`
and `optim_grads_params` will also sharding main weights
during mixed-precision training, omitted without detailed notation.
- optim_grads: Shards gradients and optimizer states, this is conceptually close to "ZeRO-2".
- optim_grads_params: Shards parameters, gradients and optimizer states, this
is conceptually close to "ZeRO-3".
Key Features:
- Compatible with MCore's tensor, context and expert parallelism
- Automatic mixed precision training (BF16/FP8)
- Gradient accumulation and bucketing
- Optimized activation recompute with shard-aware communication: When recomputing
a whole Transformer layer, gather parameters once for both the recomputation
and backward computation
- Compatible with MCore's distributed checkpointing
Args:
config: Transformer config object.
ddp_config: FullyShardedDataParallel config object.
module: Underlying model.
fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
i.e., the minimum releasable model unit. If not provided, defaults to
[TransformerLayer, LanguageModelEmbedding] for GPT-like models.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket.
Examples:
>>> model = GPTModel(config)
>>> model = FullyShardedDataParallel(
... config,
... model,
... ddp_config,
... fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding],
... )
"""
# TODO: add hybrid FSDP (shard model states in a partial DP domain)
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
):
super().__init__(config=config, module=module)
if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
self.module = module
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
self.bucket_size = self.ddp_config.bucket_size
if disable_bucketing:
self.bucket_size = None
self.device = device if device else torch.cuda.current_device()
self.param_to_bucket_group = {}
if fsdp_unit_modules is not None:
self.fsdp_unit_modules = fsdp_unit_modules
else:
self.fsdp_unit_modules = [TransformerLayer]
if not getattr(self.module, "share_embeddings_and_output_weights", False):
self.fsdp_unit_modules.append(LanguageModelEmbedding)
self.main_weights = True
self.data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
self.expert_data_parallel_group = parallel_state.get_expert_data_parallel_group()
# Determine if we should delay the gradient reduction.
self.is_delay_grad_reduce = self.ddp_config.data_parallel_sharding_strategy in [
"no_shard",
"optim",
]
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
assert self.ddp_config.overlap_param_gather
if not self.is_delay_grad_reduce:
assert self.ddp_config.overlap_grad_reduce
self._init_fsdp_param_and_grad_buffer()
self._register_fsdp_hooks(self.module)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
def _init_fsdp_param_and_grad_buffer(self):
if self.config.calculate_per_token_loss:
# We don't need to scale the gradients in this case.
gradient_scaling_factor = None
expert_gradient_scaling_factor = None
else:
if self.ddp_config.average_in_collective:
# FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR.
raise Exception("Not supported")
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Initialize the param and grad buffer.
self.data_parallel_sharding_strategy = self.ddp_config.data_parallel_sharding_strategy
self.param_to_name = {p: name for name, p in self.module.named_parameters()}
self.param_and_grad_buffer = ParamAndGradBuffer(
self.ddp_config,
self.module,
bucketing_policy=BucketingPolicy(
suggested_bucket_size=self.bucket_size,
fsdp_unit_modules=(
# Only when model weights need to be sharded, we need to
# identify the minimum releasable model unit, which is the
# FSDP Unit Module.
self.fsdp_unit_modules
if self.data_parallel_sharding_strategy == "optim_grads_params"
else []
),
data_parallel_sharding_strategy=self.data_parallel_sharding_strategy,
),
data_parallel_group=self.data_parallel_group,
expert_data_parallel_group=self.expert_data_parallel_group,
preserve_fp32_weights=self.ddp_config.preserve_fp32_weights,
grad_reduce_in_fp32=self.ddp_config.grad_reduce_in_fp32,
gradient_scaling_factor=gradient_scaling_factor,
expert_gradient_scaling_factor=expert_gradient_scaling_factor,
device=self.device,
reset_parameters_for_meta_device_init_module=self.config.init_model_with_meta_device,
)
self.param_and_grad_buffer
self.side_stream_for_buffer_copy_and_grad_accum = torch.cuda.Stream()
# Initialize the reduce-scatter pipeline.
self.grad_reduce_pipeline = GradReducePipeline(
self.param_and_grad_buffer, cuda_stream=self.side_stream_for_buffer_copy_and_grad_accum
)
# Initialize the all-gather pipeline.
self.all_gather_pipeline = AllGatherPipeline(self.param_and_grad_buffer)
self.suggested_RS_queue_capacity = self.ddp_config.suggested_communication_unit_size
self.suggested_AG_prefetch_size = self.ddp_config.suggested_communication_unit_size
def _register_fsdp_hooks(self, root_module):
"""Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
This function sets up various hooks required for FSDP operations, including parameter
resharding/unsharding and gradient handling. The registered hooks are:
- Pre-forward hook: Unshards parameters before forward pass
- Post-forward hook: Reshards parameters after forward pass
- Pre-backward hook: Unshards parameters before backward pass
- Post-backward hook: Reshards parameters after backward pass
- Gradient accumulation hook: Handles gradient accumulation and reduction across devices
Args:
root_module: The PyTorch module to register FSDP hooks on
Note:
These hooks are essential for FSDP's memory efficiency as they manage:
1. Dynamic parameter sharding/unsharding to reduce memory footprint
2. Proper gradient synchronization across distributed processes
3. Gradient accumulation for large batch training
Returns:
None
"""
# Initialize module training state.
for m in root_module.modules():
setattr(m, "_training_state", TrainingState.IDLE)
self.forward_pre_hooks = {}
self.forward_hooks = {}
self.backward_pre_hooks = {}
"""
An FSDP unit is a module designed to manage the lifecycle of model parameters
in Fully Sharded Data Parallel (FSDP) training. It ensures that parameters
are only used within the module and are released immediately after
the forward and backward computations are completed.
This approach is crucial for efficient memory management, as releasing
parameters too early can lead to issues if other computations depend on them.
`optim` and `optim_grads` do not require FSDP units because they do not
shard model parameters.
"""
if self.data_parallel_sharding_strategy != "optim_grads_params":
fsdp_unit_modules = []
else:
fsdp_unit_modules = self.fsdp_unit_modules
def release_module_parameters(module, *unused):
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.release_bucket(bucket_id)
if not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
release_params_fp8_transpose_cache(module.parameters())
def release_params_fp8_transpose_cache(params):
for param in params:
if is_float8tensor(param):
param._transpose_invalid = True
param._transpose = None
def all_gather_module_parameters(
module,
*unused,
prefetch=True,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
wait_bucket_ready=True,
):
wait_list = []
ag_pipeline = self.all_gather_pipeline
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
ag_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=prefetch,
prefetch_order=prefetch_order,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
wait_list.append(bucket_id)
if wait_bucket_ready:
for bucket_id in wait_list:
ag_pipeline.wait_bucket_ready(bucket_id)
def _post_backward(module, *unused):
release_module_parameters(module)
module._training_state = TrainingState.IDLE
def _pre_forward(module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]):
input_training_state = module._training_state
fsdp_forward_prefetch = True
if input_training_state == TrainingState.PRE_BACKWARD:
# In activation recomputation case, we need to cancel forward prefetch.
fsdp_forward_prefetch = False
else:
module._training_state = TrainingState.FORWARD
if isinstance(module, tuple(fsdp_unit_modules)):
wait_list = []
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
wait_list.append(bucket_id)
for bucket_id in wait_list:
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
if not torch.is_grad_enabled():
return args, kwargs
# Register the backward function to release the parameters.
args_list, args_spec = tree_flatten(args)
kwargs_list, kwargs_spec = tree_flatten(kwargs)
args_kwargs_list = list(args_list) + list(kwargs_list)
inp_tensor_indices: List[int] = []
inp_tensors: List[torch.Tensor] = []
for i, obj in enumerate(args_kwargs_list):
if torch.is_tensor(obj) and obj.requires_grad:
inp_tensor_indices.append(i)
inp_tensors.append(obj)
if len(inp_tensors) == 0:
return args, kwargs
inp_tensors = RegisterFSDPBackwardFunction.apply(
functools.partial(_post_backward, module), *inp_tensors
)
for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors):
args_kwargs_list[inp_tensor_idx] = inp_tensor
args_list = args_kwargs_list[: len(args_list)]
kwargs_list = args_kwargs_list[len(args_list) :]
args = tree_unflatten(args_list, args_spec)
kwargs = tree_unflatten(kwargs_list, kwargs_spec)
return args, kwargs
else:
# All-gather the parameters in every forward pass for FSDP.
for param in module.parameters(recurse=False):
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.queue_bucket_to_all_gather(
bucket_id,
prefetch=fsdp_forward_prefetch,
suggested_AG_prefetch_size=self.suggested_AG_prefetch_size,
)
for param in module.parameters(recurse=False):
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
return args, kwargs
if self.ddp_config.overlap_param_gather:
fsdp_modules = []
for name, module in root_module.named_modules():
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_pre_hooks[f'module {name} parameter all-gather'] = (
module.register_forward_pre_hook(_pre_forward, prepend=True, with_kwargs=True)
)
def _pre_backward(module: nn.Module, *unused):
module._training_state = TrainingState.PRE_BACKWARD
if isinstance(module, tuple(fsdp_unit_modules)):
all_gather_module_parameters(
module, prefetch_order=PrefetchOrder.BACKWARD_PASS_ORDER
)
def _root_pre_backward(module: nn.Module, *unused):
"""Marks the module's training state as 'pre_backward' before the
backprop, this function is registered on the root module.
This marking enables us to determine whether forward pass needs to
perform reshard/unshard operations in activation recomputation
scenarios.
"""
for module in root_module.modules():
if isinstance(module, tuple(fsdp_unit_modules)):
module._training_state = TrainingState.PRE_BACKWARD
for param in module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
self.all_gather_pipeline.wait_bucket_ready(bucket_id, empty_ok=True)
self.all_gather_pipeline.release_bucket(bucket_id)
def _post_forward(module: nn.Module, input: Any, output: Any):
# When composing with module-hook-based activation checkpointing, the
# post-backward hook is responsible for the reshard
if module._training_state == TrainingState.PRE_BACKWARD:
return output
release_module_parameters(module)
module._training_state = TrainingState.IDLE
return output
def _release_module_fp8_transpose_cache(module: nn.Module, *unused):
release_params_fp8_transpose_cache(module.parameters(recurse=False))
if self.data_parallel_sharding_strategy == "optim_grads_params":
fsdp_modules = []
for name, module in root_module.named_modules():
if any(is_submodule(module, fsdp_module) for fsdp_module in fsdp_modules):
continue
if isinstance(module, tuple(fsdp_unit_modules)):
fsdp_modules.append(module)
self.forward_hooks[f"release module {name} parameters"] = (
module.register_forward_hook(_post_forward, prepend=False)
)
self.backward_pre_hooks[f"all-gather module {name} parameters"] = (
module.register_full_backward_pre_hook(_pre_backward)
)
elif not self.ddp_config.keep_fp8_transpose_cache_when_using_custom_fsdp:
self.forward_hooks[f"remove module {name} fp8 transpose cache"] = (
module.register_forward_hook(
_release_module_fp8_transpose_cache, prepend=False
)
)
self._root_pre_backward_hook_handle = root_module.register_full_backward_pre_hook(
_root_pre_backward
)
def _make_param_hook(param: torch.nn.Parameter):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
wait_previous_grad_reduce = not self.is_delay_grad_reduce
# FIXME: Use insert forward op to replace grad acc hook, which will
# be lost after parameter data movement. For example, module.cuda()
# will cause the registered grad acc hook to be lost.
def param_hook(*unused):
if param.requires_grad:
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
if self.is_delay_grad_reduce:
param.main_grad.add_(param.grad.data)
else:
param.main_grad.copy_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce and (
not self.is_delay_grad_reduce or self.is_last_microbatch
):
gr_pipeline = self.grad_reduce_pipeline
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
gr_pipeline.place_bucket(bucket_id)
go_rs = gr_pipeline.mark_item_ready(param, async_rs=True)
if go_rs and wait_previous_grad_reduce:
gr_pipeline.wait_for_previous_grad_reduce(
recommeded_queue_capacity=self.suggested_RS_queue_capacity
)
return param_hook
# Register backward gradient accumulation hook for each parameter.
self.grad_accs = []
for param in root_module.parameters():
bucket_id = self.param_and_grad_buffer.param_to_param_group[param]
wbuf = self.param_and_grad_buffer.parameter_groups[bucket_id].model_weight_buffer
if param.requires_grad:
if wbuf and wbuf.is_data_distributed:
wbuf.fetch_bucket(and_allocate_params_data=True)
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(_make_param_hook(param))
self.grad_accs.append(grad_acc)
if wbuf and wbuf.is_data_distributed:
wbuf.free_bucket_storage()
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
For grads shard mode there will actually always be gradient sync happening.
"""
# FIXME: Better handling of grads shard mode and no_sync in the training loop so that
# the code doesn't bog down developers.
self.is_last_microbatch = False
try:
yield
finally:
self.is_last_microbatch = True
def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if not force_sync and self.ddp_config.overlap_param_gather:
# All-gather the first bucket before the forward pass.
self.all_gather_pipeline.queue_bucket_to_all_gather(bucket_id=0, prefetch=False)
else:
self.all_gather_pipeline.reset()
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.all_gather_bucket_and_set_items(
bucket_id=bucket_id, async_op=True
)
group = self.param_and_grad_buffer.parameter_groups[bucket_id]
if group.model_weight_buffer is None:
continue
if group.model_weight_buffer.is_data_distributed:
# If model weight is sharded, we wait for the all-gather to complete and
# then release the bucket immediately to save memory usage.
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
for bucket_id in range(self.all_gather_pipeline.num_buckets):
self.all_gather_pipeline.wait_bucket_ready(bucket_id)
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if not self.ddp_config.overlap_grad_reduce:
if self.data_parallel_sharding_strategy == "no_shard":
self.param_and_grad_buffer.all_reduce_gradients(
async_op=self.ddp_config.overlap_grad_reduce
)
else:
self.param_and_grad_buffer.reduce_scatter_gradients()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if self.ddp_config.overlap_grad_reduce:
self.grad_reduce_pipeline.wait_for_previous_grad_reduce(0)
self.grad_reduce_pipeline.reset()
else:
self.start_grad_sync()
self.param_and_grad_buffer.update_main_grads()
if self.ddp_config.overlap_param_gather:
self.all_gather_pipeline.reset()
def optimizer_named_parameters(self) -> List[Tuple[str, torch.Tensor]]:
"""
Returns a list of tuples containing the main weights and their corresponding names
for mixed-precision training, to be used by the optimizer for updates.
Returns:
List[Tuple[str, torch.Tensor]]: A list of tuples, where each tuple
contains a main weight tensor and its corresponding name.
"""
return self.param_and_grad_buffer.optimizer_named_parameters
def scale_gradients(self, scaling_factor: float):
"""Scale all gradients inside the buffers by `scaling_factor`."""
self.param_and_grad_buffer.scale_gradients(scaling_factor)
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
self.param_and_grad_buffer.zero_grad()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group(
with_context_parallel=True
)
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
# make a copy of the state_dict to avoid modifying the input state_dict
state_dict = state_dict.copy()
state_dict_extra_states = {}
for key in list(state_dict.keys()):
if key.endswith("_extra_state"):
state_dict_extra_states[key] = state_dict[key]
del state_dict[key]
self.module.load_state_dict(state_dict_extra_states, strict=False)
prefix = "module."
buffer = self.param_and_grad_buffer
for param_groups in buffer.parameter_groups:
wbuf = param_groups.model_weight_buffer
for model_param in wbuf.params:
if is_float8tensor(model_param):
fp8_meta = model_param._fp8_meta['scaling_fwd']
fp8_meta_index = model_param._fp8_meta_index
model_param._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])
param_name = f"{buffer.param_to_name[model_param]}"[len(prefix) :]
if param_name in state_dict:
if wbuf and wbuf.is_data_distributed:
model_param.fully_shard_param_local_shard.data.copy_(
state_dict[param_name]
)
else:
model_param.data.copy_(state_dict[param_name])
del state_dict[param_name]
self.module.load_state_dict(state_dict, strict=False)
return
self.module.load_state_dict(state_dict, strict=strict)
class RegisterFSDPBackwardFunction(torch.autograd.Function):
"""
Register a backward function that will be called after the backward pass
of the model. This function is used to release the parameters after the
backward pass.
"""
@staticmethod
def forward(ctx, post_backward, *inputs: torch.Tensor):
"""
Forward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward = post_backward
return inputs
@staticmethod
def backward(ctx, *grads: torch.Tensor):
"""
Backward pass of the RegisterFSDPBackwardFunction function.
"""
ctx.post_backward()
return (None,) + grads
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment