Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

from .core import check_is_distributed_checkpoint
-from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor
+from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor
from .serialization import (
    load,
    load_common_state_dict,
...
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -6,8 +6,7 @@ import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
-from time import time
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
+from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast

import numpy as np
import torch
@@ -15,7 +14,7 @@ import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
-from .utils import _sharded_tensor_shard_id, _ShardId
+from .utils import _sharded_tensor_shard_id, _ShardId, debug_time

# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
@@ -52,7 +51,6 @@ class ShardDistribution(NamedTuple):
            identifier to the original ShardedTensor
        all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
            need a given shard in a given parallelization group
    """

    main_rank_for_shard: Dict[_ShardId, int]
@@ -237,6 +235,7 @@ def determine_main_replica_uniform_distribution(

@torch.no_grad()
+@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
    loaded_tensors: Dict[_ShardId, torch.Tensor],
    unloaded_shards: Dict[_ShardId, ShardedTensor],
@@ -276,7 +275,7 @@ def exchange_loaded_tensors_gather_rounds(
    # Group by dtype so that we all_gather tensors of the same dtype
    for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
-        start = time()
+        with debug_time(f"dtype_{dtype}"):
            # shards_by_rank maps rank to tensors loaded by this rank
            shards_by_rank: List[List[torch.Tensor]] = [
                [] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
@@ -310,7 +309,10 @@ def exchange_loaded_tensors_gather_rounds(
                else:
                    assert isinstance(shard_id, tuple), type(shard_id)
                    if rank == local_rank:
-                        assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
+                        assert shard_id in all_loaded_tensors, (
+                            shard_id,
+                            all_loaded_tensors.keys(),
+                        )
                        orig_device = all_loaded_tensors[shard_id]
                        all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
                        local_ten = all_loaded_tensors[shard_id]
@@ -343,10 +345,6 @@ def exchange_loaded_tensors_gather_rounds(
            del round_tensors  # remove tensor references

-        end = time()
-        if torch.distributed.get_rank() == 0:
-            logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
-
    return all_loaded_tensors
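The hunks above drop the hand-rolled time() / logger.debug bookkeeping in favor of a debug_time helper imported from .utils, used both as a decorator and as a context manager. That helper is not part of this diff; a minimal sketch of such a utility, assuming it only needs to log elapsed wall time at debug level (names and log format are assumptions, not the actual Megatron implementation):

import logging
import time
from contextlib import ContextDecorator

class debug_time(ContextDecorator):
    """Hypothetical sketch: usable as `@debug_time(name, logger)` or `with debug_time(name):`."""

    def __init__(self, name: str, logger: logging.Logger = logging.getLogger(__name__)):
        self.name = name
        self.logger = logger

    def __enter__(self):
        # Record the start time when the block (or decorated call) begins.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        # Log the elapsed time; exceptions are not swallowed.
        elapsed = time.perf_counter() - self.start
        self.logger.debug(f'{self.name} took {elapsed:.4f}s')
        return False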
@@ -396,7 +394,39 @@ def exchange_loaded_tensors_gather_object(
    return all_loaded_tensors


+def exchange_loaded_objects_gather_object(
+    loaded_objects: Dict[_ShardId, Any]
+) -> Dict[_ShardId, Any]:
+    """Exchange the objects loaded by different ranks with a simple all_gather_object call.
+
+    Args:
+        loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects
+            already loaded by this rank.
+
+    Returns:
+        Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to
+            load a given state dict.
+    """
+    all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None)
+    torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None)
+    all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list)
+    all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
+
+    # Error checks
+    if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)):
+        err_msg = 'Duplicate shard ids loaded by different ranks'
+        if torch.distributed.get_rank() == 0:
+            logger.error(
+                f'{err_msg}. Shards ids by rank:'
+                f' {[lt.keys() for lt in all_loaded_objects_list]}'
+            )
+        raise CheckpointingException(err_msg)
+
+    return all_loaded_objects
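Since all_gather_object needs an initialized process group, the merge and duplicate-id check performed by the new function can be illustrated standalone; the shard ids below are made up for the example:

from functools import reduce

# What all_gather_object would return across two ranks (illustrative shard ids).
rank0_objects = {('optimizer.param_groups', (0,), (2,)): {'lr': 1e-4}}
rank1_objects = {('rng_state', (1,), (2,)): b'\x00\x01'}
gathered = [rank0_objects, rank1_objects]

# Merge the per-rank dicts and verify no shard id was loaded twice.
merged = reduce(lambda x, y: {**x, **y}, gathered)
assert len(merged) == sum(map(len, gathered)), 'Duplicate shard ids loaded by different ranks'
print(sorted(k[0] for k in merged))  # ['optimizer.param_groups', 'rng_state']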
@torch.no_grad()
+@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
    loaded_tensors: Dict[_ShardId, torch.Tensor],
    unloaded_shards: Dict[_ShardId, ShardedTensor],
@@ -427,8 +457,6 @@ def exchange_loaded_tensors_broadcast(
    all_loaded_tensors = dict(loaded_tensors)

-    start = time()
-
    for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
        if len(all_ranks_for_shard[shard_id]) == 1:
            assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
@@ -475,17 +503,13 @@ def exchange_loaded_tensors_broadcast(
            all_loaded_tensors[shard_id] = local_ten.to(orig_device)
            del local_ten

-    end = time()
-    if torch.distributed.get_rank() == 0:
-        logger.debug(f'exchange broadcast schedule took {end - start}s')
-
    return all_loaded_tensors


def exchange_by_distribution(
    loaded_tensors: Dict[_ShardId, torch.Tensor],
    unloaded_shards: Dict[_ShardId, ShardedTensor],
-    shard_distribution: ShardDistribution = None,
+    shard_distribution: ShardDistribution,
    parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
    exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
@@ -508,6 +532,7 @@ def exchange_by_distribution(
        previously loaded tensors (from `loaded_tensors` input)
    """
+    assert shard_distribution is not None, 'Expecting distribution to perform exchange'
    if exchange_algo == 'gather_object':
        exchange_fn = exchange_loaded_tensors_gather_object
    elif exchange_algo == 'gather_rounds':
...
@@ -119,7 +119,8 @@ class ShardedTensor(ShardedBase):
                self.init_data(device='meta')
                if self.data.shape != real_data.shape:
                    raise CheckpointingException(
-                        f'Data shape doesnt match expected {self.data.shape} for {self}'
+                        f'Data shape {real_data.shape} doesnt match'
+                        f' expected {self.data.shape} for {self}'
                    )
            finally:
                self.data = real_data
@@ -135,7 +136,13 @@ class ShardedTensor(ShardedBase):
            )

        for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
-            if off % sh != 0:
+            # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
+            # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
+            # GPU0 receives p0 and a portion of p1, while GPU1 receives the
+            # remaining portion of p1 and p2.
+            # As a result, there is no parameter shard of p2 on GPU0, and
+            # the shape of p2 on GPU0 is zero.
+            if sh != 0 and off % sh != 0:
                raise CheckpointingException(
                    f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
                )
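A toy illustration of why the sh != 0 guard is needed for the custom-FSDP case described in the NOTE (offsets and shapes are made up; this is not code from the diff):

# One axis has zero local size: this rank holds no elements of the parameter.
global_offset = (4, 0)
local_shape = (0, 8)

for off, sh in zip(global_offset, local_shape):
    if sh != 0 and off % sh != 0:
        raise ValueError(f'Global offset ({off}) must be divisible by local shape ({sh}).')
# Without the `sh != 0` guard, `off % sh` would raise ZeroDivisionError on the
# zero-sized axis instead of simply skipping it.
print('offsets are consistent with local shape')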
@@ -515,10 +522,6 @@ class LocalNonpersistentObject:
        return self.obj


-# TODO: Delete once NeMo fixes typo.
-LocalNonpersitentObject = LocalNonpersistentObject
-
-
@dataclass
class ShardedObject(ShardedBase):
    """Represents a mapping between a local object and a global object.
...
File mode changed from 100755 to 100644
@@ -25,7 +25,7 @@ from .mapping import (
    StateDict,
    apply_factory_merges,
)
-from .state_dict_transformation import load_preprocess, save_preprocess
+from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
    AsyncSaveShardedStrategy,
@@ -104,8 +104,6 @@ def load(
    checkpoint_dir = Path(checkpoint_dir)

    common_state_dict = common_strategy.load_common(checkpoint_dir)
-    if not sharded_state_dict:
-        return common_state_dict

    sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
        sharded_state_dict
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

-""" Utilities for transforming state_dict, including a tensor-aware implementation."""
+""" Utilities for transforming state_dict."""

-import logging
-from time import time
-from typing import Any, Callable, Optional
-
-import torch
+from typing import Callable, Union

-from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values
-from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution
+from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
    CommonStateDict,
-    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
    apply_factories,
-    apply_factory_merges,
)
-from .utils import (
-    _sharded_object_id,
-    _sharded_tensor_shard_id,
-    extract_nonpersistent,
-    extract_sharded_base,
-)
+from .utils import extract_nonpersistent, extract_sharded_base
from .validation import determine_global_metadata, validate_sharding_integrity

-logger = logging.getLogger(__name__)


def save_preprocess(
    sharded_state_dict: ShardedStateDict,
@@ -54,6 +40,7 @@ def save_preprocess(
    apply_factories(sharded_state_dict)
    _, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
    sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
+    sharded_part = filter_out_empty_flatten_tensor(sharded_part)
    if validate_access_integrity:
        preprocessed_common_state_dict = common_state_dict
        if preprocess_common_before_consistancy_check:
@@ -84,6 +71,7 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
    # Create a copy of sharded_state_dict as the passed in state dict may have
    # references that prevent tensors from being deallocated
    sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
+    sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict)
    sh_ten_factories, _ = extract_matching_values(
        sharded_state_dict,
@@ -100,171 +88,25 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
    return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories
-
-def prepare_state_dict_for_save(
-    sharded_state_dict: ShardedStateDict,
-    async_prepare: bool = False,
-    algo: str = 'atomic',
-    validate_access_integrity: bool = True,
-    parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
-    to_cpu: bool = True,
-):
-    """Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager.
-
-    Args:
-        sharded_state_dict (ShardedStateDict): The initial state dictionary.
-        async_prepare (bool): If True, enables asynchronous preparation.
-        algo (str): The algorithm used to create the tensor-aware state dictionary.
-        validate_access_integrity (bool): If True, validates sharding integrity.
-        parallelization_group (torch.distributed.ProcessGroup):
-            The process group used for exchanges to avoid duplications.
-        to_cpu (bool): If True, moves all tensors from device to CPU.
-
-    Returns:
-        ShardedStateDict: The tensor-aware state dictionary.
-    """
-    _start = time()
-
-    if async_prepare:
-        raise NotImplementedError('Async state_dict preparation is not yet implemented')
-    if algo != 'atomic' and algo != 'fully_parallel':
-        raise NotImplementedError(
-            'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
-        )
-    fully_parallel = algo == 'fully_parallel'
-
-    sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity)
-    sharded_tensors = []
-    sharded_objects = []
-    for sh_base in nested_values(sharded_part):
-        if isinstance(sh_base, ShardedTensor):
-            sharded_tensors.append(sh_base)
-        else:
-            assert isinstance(sh_base, ShardedObject)
-            sharded_objects.append(sh_base)
-    if fully_parallel:
-        shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution(
-            sharded_part, parallelization_group, True
-        )
-
-    raw_tensors, raw_objects = {}, {}
-    for ten in sharded_tensors:
-        shard_id = _sharded_tensor_shard_id(ten)
-        if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank():
-            # TODO cover creating copies on host in CheckpointManager.save()
-            if to_cpu:
-                raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True)
-            else:
-                raw_tensors[shard_id] = ten.data
-        ten.data = None
-    for obj in sharded_objects:
-        raw_objects[_sharded_object_id(obj)] = obj.data
-        obj.data = None
-    logger.debug(f'prepare_state_dict_for_save took {time() - _start}')
-
-    state_dict_for_save = {
-        'raw_tensors': raw_tensors,
-        'raw_objects': raw_objects,
-        'common': common_state_dict,
-        'sharded_state_dict': sharded_part,
-    }
-    if fully_parallel:
-        state_dict_for_save['shard_to_rank'] = shard_to_saving_rank
-        state_dict_for_save['shard_to_metadata'] = shard_to_metadata
-    return state_dict_for_save
-
-
-def recreate_state_dict_after_load(
-    sharded_state_dict: ShardedStateDict,
-    loaded_state_dict: ShardedStateDict,
-    algo: str = 'atomic',
-    exchange_algo: str = 'broadcast',
-    validate_access_integrity: bool = True,
-    parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
-):
-    """Creates a final sharded state dictionary from a tensor-aware state dictionary.
-
-    Args:
-        sharded_state_dict (ShardedStateDict):
-            The initial sharded state dictionary generated from the model.
-        loaded_state_dict (ShardedStateDict):
-            Tensor-aware state dictionary used to fill in missing data in the sharded state.
-        algo (str): The algorithm used to reconstruct the state dictionary
-            from the tensor-aware state dictionary.
-        exchange_algo (str): The algorithm used for tensor exchanges during retrieval.
-        validate_access_integrity (bool): If True, performs validation of sharding integrity.
-        parallelization_group (torch.distributed.ProcessGroup):
-            The process group used for efficient exchanges during retrieval.
-
-    Returns:
-        ShardedStateDict: The finalized sharded state dictionary.
-    """
-    if algo != 'atomic' and algo != 'fully_parallel':
-        raise NotImplementedError(
-            'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
-        )
-    fully_parallel = algo == 'fully_parallel'
-
-    # __adding__ common part
-    recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True)
-
-    if not sharded_state_dict:
-        return recreated_state_dict
-
-    # TODO validate laoded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible
-    sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
-        sharded_state_dict
-    )
-    # __adding__ nonpersistent part
-    merge(recreated_state_dict, nonpersistent_state_dict)
-
-    sharded_part, _ = extract_sharded_base(sharded_state_dict)
-    if validate_access_integrity:
-        validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
-
-    # load sharded tensors and sharded objects to sharded_part
-    loaded_tensors = loaded_state_dict['raw_tensors']
-    # TODO cover restoring the original device (H2D) in CheckpointManager.load()
-    for k, v in loaded_tensors.items():
-        loaded_tensors[k] = v.cuda()  # H2D
-    if fully_parallel:
-        distribution = (
-            loaded_state_dict['shard_to_rank'],
-            None,
-            loaded_state_dict['shard_to_metadata'],
-        )
-        unloaded_shards = {}
-        for sh_base in nested_values(sharded_part):
-            if isinstance(sh_base, ShardedTensor):
-                shard_id = _sharded_tensor_shard_id(sh_base)
-                if shard_id not in loaded_tensors:
-                    unloaded_shards[shard_id] = sh_base
-        loaded_tensors = exchange_by_distribution(
-            loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo
-        )
-    loaded_objects = loaded_state_dict['raw_objects']
-
-    def load_sharded_base(x: Any):
-        if isinstance(x, ShardedTensor):
-            shard_id = _sharded_tensor_shard_id(x)
-            if shard_id not in loaded_tensors:
-                raise Exception(
-                    'The current local checkpoint implementation assumes'
-                    'consistent tensor sharding during load and save operations.'
-                    f'However, the expected shard {x} (ID: {shard_id})'
-                    f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})'
-                )
-            x = loaded_tensors[shard_id]
-        if isinstance(x, ShardedObject):
-            object_id = _sharded_object_id(x)
-            assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
-            x = loaded_objects[object_id]
-        return x
-
-    dict_list_map_inplace(load_sharded_base, sharded_part)
-    sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
-    # __adding__ sharded_part
-    merge(recreated_state_dict, sharded_part)
-    return recreated_state_dict
+
+def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]):
+    """
+    Filter out ShardedTensors with empty flatten_range.
+
+    These tensors can cause the PyTorch check in
+    `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
+
+    Args:
+        sharded_state_dict: state dict possibly containing ShardedTensor objects
+    """
+    # Filter out ShardedTensors with empty flatten_range.
+    # These tensors can cause the PyTorch check in
+    # `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
+    # This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases.
+    sharded_state_dict, _ = extract_matching_values(
+        sharded_state_dict,
+        lambda v: not (
+            isinstance(v, ShardedTensor)
+            and v.flattened_range
+            and v.flattened_range.start == v.flattened_range.stop
+        ),
+    )
+    return sharded_state_dict
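For illustration, the predicate used by filter_out_empty_flatten_tensor can be reproduced on a plain dict with hypothetical stand-in classes (FakeRange and FakeShardedTensor are not real Megatron types); only entries whose flattened range is empty are dropped:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeRange:
    start: int
    stop: int

@dataclass
class FakeShardedTensor:
    key: str
    flattened_range: Optional[FakeRange] = None

def is_empty_flattened(v) -> bool:
    # Mirrors the lambda above: a ShardedTensor-like object whose
    # flattened range starts and stops at the same index holds no data.
    return (
        isinstance(v, FakeShardedTensor)
        and v.flattened_range is not None
        and v.flattened_range.start == v.flattened_range.stop
    )

state_dict = {
    'decoder.weight': FakeShardedTensor('decoder.weight', FakeRange(0, 128)),  # kept
    'decoder.bias': FakeShardedTensor('decoder.bias', FakeRange(64, 64)),      # dropped (empty)
    'step': 100,                                                               # kept (not a tensor)
}
kept = {k: v for k, v in state_dict.items() if not is_empty_flattened(v)}
print(sorted(kept))  # ['decoder.weight', 'step']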