Commit 2c63b5cd authored by wangxj's avatar wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor
from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -6,8 +6,7 @@ import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
......@@ -15,7 +14,7 @@ import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId
from .utils import _sharded_tensor_shard_id, _ShardId, debug_time
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
......@@ -52,7 +51,6 @@ class ShardDistribution(NamedTuple):
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
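For orientation, a small illustration (not part of this commit) of the two mappings described in the docstring, using hypothetical shard ids for a 3-rank group; the real `_ShardId` is an internal tuple built by `_sharded_tensor_shard_id` from the `ShardedTensor` key and offsets:

    # Hypothetical shard ids, illustration only.
    main_rank_for_shard = {
        ('decoder.weight', 0): 0,   # rank 0 loads this shard from storage
        ('decoder.weight', 1): 2,   # rank 2 loads this shard from storage
    }
    all_ranks_for_shard = {
        ('decoder.weight', 0): [0, 1],  # ranks 0 and 1 need it, so an exchange is required
        ('decoder.weight', 1): [2],     # only the loading rank needs it, no exchange
    }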
......@@ -237,6 +235,7 @@ def determine_main_replica_uniform_distribution(
@torch.no_grad()
@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
......@@ -276,76 +275,75 @@ def exchange_loaded_tensors_gather_rounds(
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
start = time()
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
with debug_time(f"dtype_{dtype}"):
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
del round_tensors # remove tensor references
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
del round_tensors # remove tensor references
return all_loaded_tensors
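As a side note for reviewers: the round scheduling above (both the old and the re-indented code) relies on `zip_longest` to transpose the per-rank shard lists into exchange rounds, padding with `None` for ranks that have run out of shards (those send an empty tensor in that round). A tiny standalone illustration, not part of this commit:

    from itertools import zip_longest

    shards_by_rank = [
        ['a0', 'a1', 'a2'],   # shard ids loaded by rank 0 (hypothetical ids)
        ['b0'],               # shard ids loaded by rank 1
        ['c0', 'c1'],         # shard ids loaded by rank 2
    ]
    for round_idx, round_shard_ids in enumerate(zip_longest(*shards_by_rank, fillvalue=None)):
        print(round_idx, round_shard_ids)
    # 0 ('a0', 'b0', 'c0')
    # 1 ('a1', None, 'c1')
    # 2 ('a2', None, None)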
......@@ -396,7 +394,39 @@ def exchange_loaded_tensors_gather_object(
return all_loaded_tensors
def exchange_loaded_objects_gather_object(
loaded_objects: Dict[_ShardId, Any]
) -> Dict[_ShardId, Any]:
"""Exchange the objects loaded by different ranks with a simple all_gather_object call.
Args:
loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects
already loaded by this rank.
Returns:
Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to
load a given state dict.
"""
all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None)
torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None)
all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list)
all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
# Error checks
if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shards ids by rank:'
f' {[lt.keys() for lt in all_loaded_objects_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_objects
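The new `exchange_loaded_objects_gather_object` mirrors the tensor variant: each rank contributes its dict of loaded objects, the per-rank dicts are folded into one, and a size mismatch indicates duplicate shard ids. A standalone sketch of that merge and check, using hypothetical shard-id tuples (not part of this commit):

    from functools import reduce

    all_loaded_objects_list = [
        {('optimizer.state', (0,)): 'obj-from-rank0'},
        {('optimizer.state', (1,)): 'obj-from-rank1'},
    ]
    merged = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
    # If two ranks reported the same shard id, len(merged) would be smaller than the sum.
    assert len(merged) == sum(map(len, all_loaded_objects_list))  # no duplicate shard ids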
@torch.no_grad()
@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
......@@ -427,8 +457,6 @@ def exchange_loaded_tensors_broadcast(
all_loaded_tensors = dict(loaded_tensors)
start = time()
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
......@@ -475,17 +503,13 @@ def exchange_loaded_tensors_broadcast(
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')
return all_loaded_tensors
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
......@@ -508,6 +532,7 @@ def exchange_by_distribution(
previously loaded tensors (from `loaded_tensors` input)
"""
assert shard_distribution is not None, 'Expecting distribution to perform exchange'
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
......
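The signature change above makes `shard_distribution` a required positional argument (the `= None` default is gone), with the assert kept as a safety net, and the function then dispatches on `exchange_algo`. A minimal sketch of that dispatch pattern (not part of this commit; the `'broadcast'` branch and the error message are assumptions based on the default value of `exchange_algo` and the functions shown in this file):

    def pick_exchange_fn(exchange_algo: str):
        # Minimal dispatch sketch; the real function passes loaded/unloaded shards and
        # the parallelization group to the selected exchange_fn.
        if exchange_algo == 'gather_object':
            return exchange_loaded_tensors_gather_object
        elif exchange_algo == 'gather_rounds':
            return exchange_loaded_tensors_gather_rounds
        elif exchange_algo == 'broadcast':
            return exchange_loaded_tensors_broadcast
        raise NotImplementedError(f'Unrecognized exchange algorithm: {exchange_algo}')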
......@@ -119,7 +119,8 @@ class ShardedTensor(ShardedBase):
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape doesnt match expected {self.data.shape} for {self}'
f'Data shape {real_data.shape} doesnt match'
f' expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
......@@ -135,7 +136,13 @@ class ShardedTensor(ShardedBase):
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
# NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
# For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
# GPU0 receives p0 and a portion of p1, while GPU1 receives the
# remaining portion of p1 and p2.
# As a result, there is no parameter shard of p2 on GPU0, and
# the shape of p2 on GPU0 is zero.
if sh != 0 and off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
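The relaxed check above only skips the divisibility test for zero-sized local shards (the custom-FSDP case described in the NOTE); genuinely misaligned offsets still raise. A simplified standalone sketch of the guard (not part of this commit; the real check also skips prepended axes):

    def check_offsets(global_offset, local_shape):
        for off, sh in zip(global_offset, local_shape):
            if sh != 0 and off % sh != 0:
                raise ValueError(f'Global offset ({off}) must be divisible by local shape ({sh})')

    check_offsets(global_offset=(128, 0), local_shape=(64, 0))  # ok: zero-sized shard is skipped
    try:
        check_offsets(global_offset=(100,), local_shape=(64,))  # misaligned -> raises
    except ValueError as e:
        print(e)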
......@@ -515,10 +522,6 @@ class LocalNonpersistentObject:
return self.obj
# TODO: Delete once NeMo fixes typo.
LocalNonpersitentObject = LocalNonpersistentObject
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
......
File mode changed from 100755 to 100644
......@@ -25,7 +25,7 @@ from .mapping import (
StateDict,
apply_factory_merges,
)
from .state_dict_transformation import load_preprocess, save_preprocess
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
......@@ -104,8 +104,6 @@ def load(
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
""" Utilities for transforming state_dict."""
import logging
from time import time
from typing import Any, Callable, Optional
from typing import Callable, Union
import torch
from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values
from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
CommonStateDict,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
extract_nonpersistent,
extract_sharded_base,
)
from .utils import extract_nonpersistent, extract_sharded_base
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
def save_preprocess(
sharded_state_dict: ShardedStateDict,
......@@ -54,6 +40,7 @@ def save_preprocess(
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
sharded_part = filter_out_empty_flatten_tensor(sharded_part)
if validate_access_integrity:
preprocessed_common_state_dict = common_state_dict
if preprocess_common_before_consistancy_check:
......@@ -84,6 +71,7 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
......@@ -100,171 +88,25 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories
def prepare_state_dict_for_save(
sharded_state_dict: ShardedStateDict,
async_prepare: bool = False,
algo: str = 'atomic',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
to_cpu: bool = True,
):
"""Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary.
async_prepare (bool): If True, enables asynchronous preparation.
algo (str): The algorithm used to create the tensor-aware state dictionary.
validate_access_integrity (bool): If True, validates sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for exchanges to avoid duplications.
to_cpu (bool): If True, moves all tensors from device to CPU.
Returns:
ShardedStateDict: The tensor-aware state dictionary.
def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]):
"""
_start = time()
if async_prepare:
raise NotImplementedError('Async state_dict preparation is not yet implemented')
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity)
sharded_tensors = []
sharded_objects = []
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
sharded_tensors.append(sh_base)
else:
assert isinstance(sh_base, ShardedObject)
sharded_objects.append(sh_base)
if fully_parallel:
shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
raw_tensors, raw_objects = {}, {}
for ten in sharded_tensors:
shard_id = _sharded_tensor_shard_id(ten)
if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank():
# TODO cover creating copies on host in CheckpointManager.save()
if to_cpu:
raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True)
else:
raw_tensors[shard_id] = ten.data
ten.data = None
for obj in sharded_objects:
raw_objects[_sharded_object_id(obj)] = obj.data
obj.data = None
logger.debug(f'prepare_state_dict_for_save took {time() - _start}')
state_dict_for_save = {
'raw_tensors': raw_tensors,
'raw_objects': raw_objects,
'common': common_state_dict,
'sharded_state_dict': sharded_part,
}
if fully_parallel:
state_dict_for_save['shard_to_rank'] = shard_to_saving_rank
state_dict_for_save['shard_to_metadata'] = shard_to_metadata
return state_dict_for_save
def recreate_state_dict_after_load(
sharded_state_dict: ShardedStateDict,
loaded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""Creates a final sharded state dictionary from a tensor-aware state dictionary.
Filter out ShardedTensors with an empty flattened_range.
These tensors can cause the PyTorch check in
`TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
Args:
sharded_state_dict (ShardedStateDict):
The initial sharded state dictionary generated from the model.
loaded_state_dict (ShardedStateDict):
Tensor-aware state dictionary used to fill in missing data in the sharded state.
algo (str): The algorithm used to reconstruct the state dictionary
from the tensor-aware state dictionary.
exchange_algo (str): The algorithm used for tensor exchanges during retrieval.
validate_access_integrity (bool): If True, performs validation of sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for efficient exchanges during retrieval.
Returns:
ShardedStateDict: The finalized sharded state dictionary.
sharded_state_dict: state dict possibly containing ShardedTensor objects
"""
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate laoded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
# Filter out ShardedTensors with empty flatten_range.
# These tensors can cause the PyTorch check in
# `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
# This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases.
sharded_state_dict, _ = extract_matching_values(
sharded_state_dict,
lambda v: not (
isinstance(v, ShardedTensor)
and v.flattened_range
and v.flattened_range.start == v.flattened_range.stop
),
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
loaded_tensors = loaded_state_dict['raw_tensors']
# TODO cover restoring the original device (H2D) in CheckpointManager.load()
for k, v in loaded_tensors.items():
loaded_tensors[k] = v.cuda() # H2D
if fully_parallel:
distribution = (
loaded_state_dict['shard_to_rank'],
None,
loaded_state_dict['shard_to_metadata'],
)
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
loaded_tensors = exchange_by_distribution(
loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo
)
loaded_objects = loaded_state_dict['raw_objects']
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
if shard_id not in loaded_tensors:
raise Exception(
'The current local checkpoint implementation assumes'
'consistent tensor sharding during load and save operations.'
f'However, the expected shard {x} (ID: {shard_id})'
f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})'
)
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
return sharded_state_dict
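For clarity, a minimal sketch of the filtering predicate added above, using a tiny stand-in class instead of the real `ShardedTensor` (illustration only, not part of this commit): shards whose `flattened_range` is an empty slice (start == stop) are dropped, everything else is kept.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeShardedTensor:            # stand-in for ShardedTensor, illustration only
        key: str
        flattened_range: Optional[slice] = None

    def keep(v) -> bool:
        return not (
            isinstance(v, FakeShardedTensor)
            and v.flattened_range
            and v.flattened_range.start == v.flattened_range.stop
        )

    state_dict = {
        'p1': FakeShardedTensor('p1', slice(0, 1024)),   # kept
        'p2': FakeShardedTensor('p2', slice(512, 512)),  # empty shard on this rank -> dropped
        'p3': FakeShardedTensor('p3', None),             # not flattened -> kept
    }
    filtered = {k: v for k, v in state_dict.items() if keep(v)}
    print(sorted(filtered))   # ['p1', 'p3']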
File mode changed from 100755 to 100644
......@@ -4,17 +4,36 @@
This module provides async utilities which allow starting
a checkpoint save process in the background.
"""
import gc
import logging
from abc import ABC, abstractmethod
from collections import deque
from time import time
from typing import Callable, List, NamedTuple, Optional, Tuple
from contextlib import contextmanager
from queue import Empty
from time import sleep, time
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
import torch
from torch import multiprocessing as mp
from ..utils import debug_time
logger = logging.getLogger(__name__)
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
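A quick illustration of the `_disable_gc` helper defined above (not part of this commit; assumes garbage collection is enabled on entry, which is the default): GC is paused inside the context or decorated call and restored afterwards, avoiding collector pauses while the checkpoint worker is being started.

    import gc

    with _disable_gc():
        assert not gc.isenabled()      # GC is off while the worker is being started
    assert gc.isenabled()              # previous state restored on exit

    @_disable_gc()                     # contextmanager objects also work as decorators
    def start_worker():
        return gc.isenabled()          # -> False inside the call

    assert start_worker() is False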
class AsyncRequest(NamedTuple):
"""Represents an async request that needs to be scheduled for execution.
......@@ -24,12 +43,22 @@ class AsyncRequest(NamedTuple):
finalize_fns (List[Callable]): list of functions to call to finalize the request.
These functions will be called synchronously after `async_fn` is done
*on all ranks*.
async_fn_kwargs (Dict): kwargs to pass to `async_fn`.
preload_fn (Callable): preload function to stage tensors from GPU to host.
This should be self-contained, with its arguments already bound, e.g. via `partial`.
is_frozen (bool): a flag indicating whether this async request can still be modified.
call_idx (int): index used to order async requests so that preloading and writing
tensors stay synchronized with the async caller.
"""
async_fn: Optional[Callable]
async_fn_args: Tuple
finalize_fns: List[Callable]
async_fn_kwargs: Dict = {}
preload_fn: Callable = None
is_frozen: bool = False
call_idx: int = 0
def add_finalize_fn(self, fn: Callable) -> None:
"""Adds a new finalize function to the request.
......@@ -66,7 +95,70 @@ class AsyncRequest(NamedTuple):
return self._replace(is_frozen=True)
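A sketch (not part of this commit) of building an `AsyncRequest` with the new fields; `save_to_disk` and `stage_to_host` are hypothetical helpers. Note that when `preload_fn` is set, the callers shown further below replace `async_fn_args[1]` with the result of `preload_fn()` (the staged state dict), so the second positional argument is only a placeholder here.

    from functools import partial

    import torch

    def save_to_disk(path, state_dict):            # hypothetical async_fn run in the worker
        torch.save(state_dict, path)

    def stage_to_host(state_dict):                 # hypothetical preload_fn: D2H staging
        return {k: v.to('cpu', non_blocking=True) for k, v in state_dict.items()}

    state_dict = {'weight': torch.ones(4)}
    req = AsyncRequest(
        async_fn=save_to_disk,
        async_fn_args=('ckpt.pt', None),           # slot 1 is overwritten by preload_fn()
        finalize_fns=[lambda: None],               # run synchronously on all ranks afterwards
        preload_fn=partial(stage_to_host, state_dict),
    )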
class DistributedAsyncCaller:
class AsyncCaller(ABC):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
@abstractmethod
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Schedule `async_req` with some process forking or reusing
persistent worker
This method must be called on all ranks.
Args:
async_req (AsyncRequest): `AsyncRequest` object containing to
start async process
"""
raise NotImplementedError("This should be implemented")
@abstractmethod
def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
raise NotImplementedError("This should be implemented")
def sync_all_async_calls(self, is_alive: int) -> bool:
"""Check if all ranks have completed async checkpoint writing
Args:
is_alive (int): non-zero if the current async request on this rank is not yet completed
Returns:
bool: True if all ranks are done, False if at least one rank is still active.
"""
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten)
return ten[0] == 0
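The helper above uses a single-integer all-reduce as a cheap barrier-like check: every rank contributes 1 if its writer is still active, so the sum is 0 exactly when all ranks are done. A non-distributed illustration of the same logic (not part of this commit):

    # Plain-Python stand-in for the all_reduce: summing per-rank "is_alive" flags.
    per_rank_is_alive = [0, 0, 1, 0]        # rank 2 is still writing its checkpoint
    all_done = sum(per_rank_is_alive) == 0  # False until every rank reports 0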
@abstractmethod
def close(self):
"""Terminate the async caller at exit of an application or some termination conditions"""
logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")
def __del__(self):
self.close()
class TemporalAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
......@@ -76,7 +168,8 @@ class DistributedAsyncCaller:
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None
def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple) -> None:
@_disable_gc()
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
......@@ -84,27 +177,35 @@ class DistributedAsyncCaller:
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
save_args (Tuple): async function args.
async_req (AsyncRequest): `AsyncRequest` object describing the async process to start
"""
if async_fn is None:
if async_req.async_fn is None:
return # nothing to do
async_fn_args = list(async_req.async_fn_args)
if async_req.preload_fn:
# If `async_req` defines a preload_fn, call it here to perform the staging it
# describes, i.e. move GPU tensors to their defined destination (typically host memory)
async_fn_args[1] = async_req.preload_fn()
rank = torch.distributed.get_rank()
start_sync = time()
torch.cuda.synchronize()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {end_sync - start_sync} to finish D2H "
)
logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ")
ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(target=async_fn, args=save_args)
self.process = ctx.Process(
target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs
)
self.process.start()
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} to schedule async ckpt "
)
logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ")
def is_current_async_call_done(self, blocking=False) -> bool:
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
......@@ -114,31 +215,229 @@ class DistributedAsyncCaller:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
# The following takes the same overhead as torch.distributed.barrier (single integer all-reduce)
# The following takes the same overhead
# as torch.distributed.barrier (single integer all-reduce)
is_alive = int(self.process.is_alive()) if self.process is not None else 0
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
if not is_done and blocking:
self.close()
is_done = True
return is_done
def close(self):
if self.process:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
"TemporalAsyncCaller: Async process join finished "
f"after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
class PersistentAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: mp.Process = None
self.start_time: Optional[float] = None
ctx = mp.get_context('spawn')
# main queue to deliver `AsyncRequest` from host to the ckpt worker
self.queue: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to synchronize for the completion of preloading tensors to host
# between a trainer and ckpt worker
self.preload_q: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to inform trainer when the saving is completed
self.comp_q: mp.Queue = ctx.Queue()
self.cur_item: int = None
self.cur_idx: int = -1
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Put `AsyncRequest` to the Persistent Async Caller
This method must be called on all ranks.
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
async_req (AsyncRequest): `AsyncRequest` object describing the checkpointing request to schedule
"""
if async_req.async_fn is None:
return # nothing to do
start_sync = end_sync = None
self.start_time = time()
if self.process is None:
ctx = mp.get_context('spawn')
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller"
)
self.process: mp.Process = ctx.Process(
target=PersistentAsyncCaller.async_loop,
args=(
torch.distributed.get_rank(),
self.queue,
self.preload_q,
self.comp_q,
logger.getEffectiveLevel(),
),
)
self.process.start()
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller"
)
if async_req.preload_fn:
self.preload_q.put(async_req.call_idx)
self.queue.put(async_req)
logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx}")
if async_req.preload_fn:
start_sync = time()
# Synchronize for pre-staging tensors
self.preload_q.join()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, "
f"takes {end_sync - start_sync} to finish D2H "
)
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}"
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} "
"to schedule async ckpt "
)
torch.distributed.all_reduce(ten)
if ten[0] > 0 and not blocking:
return False
else:
if self.process is not None:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
return True
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, Optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
is_alive: bool = False
if self.process:
while self.cur_item is None:
try:
# Retrieve comp call_idx without waiting
self.cur_item = self.comp_q.get_nowait()
except Empty:
# This method is only called after an `AsyncRequest` has been pushed to the main loop,
# so the background write is still in progress until the worker puts its call_idx
# into `comp_q`
if not blocking:
is_alive = True
break
sleep(0.1)
if self.cur_item is not None:
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed, {is_alive}"
)
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
# This may remain False when blocking == False; the routine is then called again later
# and simply re-runs `sync_all_async_calls` to check whether other ranks have completed the writing
if is_done:
# The current request is completed globally. Reset the current item for polling.
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed globally, {is_done}"
)
self.cur_item = None
return is_done
def close(self):
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
)
if self.process:
self.queue.put('DONE')
self.queue.join()
self.process.join()
self.process = None
@staticmethod
@_disable_gc()
def async_loop(
rank: int,
queue: mp.JoinableQueue,
preload_q: mp.JoinableQueue,
comp_q: mp.Queue,
log_level: int = logging.INFO,
):
"""Main function for the persistent checkpoint worker
The persistent worker is created once and terminated at exit or
when the application calls `close()` explicitly.
This routine receives an `AsyncRequest`, runs its `preload_fn` first and
puts an integer value into `preload_q` to inform the trainer that it can proceed.
When the request's `async_fn` is completed (background saving is done),
it puts an integer value into `comp_q` to notify the trainer of the completion.
Args:
rank (int): the rank of the trainer where the persistent worker is created.
queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest`s
from the training rank
preload_q (mp.JoinableQueue): a queue to inform the trainer that preloading of tensors
from GPU to host (or another dedicated location) is completed
comp_q (mp.Queue): a queue to inform the training rank of the completion of a scheduled
async checkpoint request
log_level (int, Optional): log level to set in this spawned process so that it is
aligned with the training rank's logging level
"""
logger = logging.getLogger(__name__)
logger.setLevel(log_level)
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started")
while True:
item = queue.get()
if isinstance(item, str) and item == 'DONE':
queue.task_done()
break
elif isinstance(item, AsyncRequest):
async_fn_args = list(item.async_fn_args)
if item.preload_fn:
call_idx = preload_q.get()
# the 2nd arg is state dict
async_fn_args[1] = item.preload_fn()
logger.debug(f"{rank} has completed D2H of {call_idx}")
preload_q.task_done()
item.async_fn(*async_fn_args, **item.async_fn_kwargs)
logger.debug(f"{rank} has completed saving {item.call_idx}")
comp_q.put(item.call_idx)
queue.task_done()
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated")
class _ActiveAsyncRequest(NamedTuple):
......@@ -152,7 +451,7 @@ class _ActiveAsyncRequest(NamedTuple):
"""
idx: int
async_caller: DistributedAsyncCaller
async_caller: AsyncCaller
async_request: AsyncRequest
......@@ -163,9 +462,18 @@ class AsyncCallsQueue:
active calls with `maybe_finalize_async_calls`.
"""
def __init__(self):
def __init__(self, persistent: bool = False):
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
self.call_idx: int = -1
self.persistent: bool = persistent
self.persistent_caller: AsyncCaller = None
def _get_async_caller(self):
if not self.persistent:
return TemporalAsyncCaller()
if self.persistent_caller is None:
self.persistent_caller = PersistentAsyncCaller()
return self.persistent_caller
def schedule_async_request(self, async_request: AsyncRequest) -> int:
"""Start a new async call and add it to a queue of active async calls.
......@@ -180,13 +488,20 @@ class AsyncCallsQueue:
This can help the user keep track of the async calls.
"""
self.call_idx += 1
async_caller = DistributedAsyncCaller()
async_caller = self._get_async_caller()
# Backward compatibility for local checkpointing built with the old AsyncRequest
if len(async_request._fields) != len(AsyncRequest._fields):
async_request = AsyncRequest(**async_request._asdict())
async_request = async_request._replace(call_idx=self.call_idx)
finalize_fns = async_request.finalize_fns
async_request = async_request._replace(finalize_fns=None)
async_request = async_request.freeze()
async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
async_caller.schedule_async_call(async_request)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, finalize_fns))
return self.call_idx
def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
"""Finalizes all available calls.
This method must be called on all ranks.
......@@ -201,18 +516,20 @@ class AsyncCallsQueue:
"""
call_idx_finalized = []
while self.async_calls:
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(
blocking, no_dist
)
if not next_async_done:
break
call_idx, _, async_request = self.async_calls.popleft()
for finalize_fn in async_request.finalize_fns:
finalize_fn()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert (
ten.item() == call_idx
), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
call_idx_finalized.append(call_idx)
with debug_time("finalize", logger):
call_idx, _, finalize_fns = self.async_calls.popleft()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert ten.item() == call_idx, (
'Unmatched async calls. That probably means not all ranks'
' are participating in async finalization'
)
for finalize_fn in finalize_fns:
finalize_fn()
call_idx_finalized.append(call_idx)
return call_idx_finalized
def get_num_unfinalized_calls(self):
......@@ -222,3 +539,5 @@ class AsyncCallsQueue:
def close(self):
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
if self.persistent and self.persistent_caller:
self.persistent_caller.close()
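Finally, a hypothetical end-to-end usage sketch of the persistent caller path (not part of this commit; assumes an initialized torch.distributed process group and the `AsyncRequest` `req` built in the earlier sketch):

    calls_queue = AsyncCallsQueue(persistent=True)
    call_idx = calls_queue.schedule_async_request(req)

    # Training continues; poll occasionally without blocking.
    finished = calls_queue.maybe_finalize_async_calls(blocking=False)

    # At teardown: wait for outstanding saves and stop the persistent worker.
    calls_queue.close()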