Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed
File mode changed from 100755 to 100644 (multiple files)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor
from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
......
......@@ -6,8 +6,7 @@ import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
......@@ -15,7 +14,7 @@ import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId
from .utils import _sharded_tensor_shard_id, _ShardId, debug_time
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
......@@ -52,7 +51,6 @@ class ShardDistribution(NamedTuple):
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
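Note: a minimal stand-in sketching the two ShardDistribution fields visible in this hunk. The real NamedTuple in exchange_utils.py may carry additional fields, and the _ShardId alias below is only an illustrative assumption.

from typing import Dict, List, NamedTuple, Tuple

_ShardId = Tuple[str, int]  # illustrative stand-in for the real _ShardId alias

class ShardDistributionSketch(NamedTuple):
    # which rank is responsible for loading/saving each shard
    main_rank_for_shard: Dict[_ShardId, int]
    # which ranks in the parallelization group need a given shard
    all_ranks_for_shard: Dict[_ShardId, List[int]]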
......@@ -237,6 +235,7 @@ def determine_main_replica_uniform_distribution(
@torch.no_grad()
@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
......@@ -276,76 +275,75 @@ def exchange_loaded_tensors_gather_rounds(
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
start = time()
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
with debug_time(f"dtype_{dtype}"):
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if they were originally on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
del round_tensors # remove tensor references
# Move tensors back to CPU if they were originally on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
del round_tensors # remove tensor references
return all_loaded_tensors
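Note: the round scheduling above relies on transposing shards_by_rank with zip_longest. A self-contained sketch with toy shard ids (no torch.distributed required) shows how per-rank lists become per-round tuples padded with None:

from itertools import zip_longest

# Toy data: shard ids loaded by each of three ranks (index = rank).
shards_by_rank = [['a0', 'a1'], ['b0'], ['c0', 'c1', 'c2']]

# Transpose so each round gathers at most one shard per rank; ranks that
# ran out of shards contribute None (an empty CUDA tensor in the real code).
for round_idx, round_shard_ids in enumerate(zip_longest(*shards_by_rank, fillvalue=None)):
    print(round_idx, round_shard_ids)
# 0 ('a0', 'b0', 'c0')
# 1 ('a1', None, 'c1')
# 2 (None, None, 'c2')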
......@@ -396,7 +394,39 @@ def exchange_loaded_tensors_gather_object(
return all_loaded_tensors
def exchange_loaded_objects_gather_object(
loaded_objects: Dict[_ShardId, Any]
) -> Dict[_ShardId, Any]:
"""Exchange the objects loaded by different ranks with a simple all_gather_object call.
Args:
loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects
already loaded by this rank.
Returns:
Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to
load a given state dict.
"""
all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None)
torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None)
all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list)
all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
# Error checks
if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shards ids by rank:'
f' {[lt.keys() for lt in all_loaded_objects_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_objects
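Note: a single-process sketch of the merge-and-check pattern used above: per-rank dicts are merged with functools.reduce, and the length comparison catches shard ids loaded by more than one rank. All names here are toy data.

from functools import reduce

# Per-rank results as they would come back from all_gather_object.
per_rank = [{'shard_a': 1}, {'shard_b': 2}, {'shard_a': 3}]  # 'shard_a' loaded twice

merged = reduce(lambda x, y: {**x, **y}, per_rank)

# Duplicates collapse in the merged dict, so the totals disagree.
if len(merged) != sum(map(len, per_rank)):
    print('Duplicate shard ids loaded by different ranks')  # mirrors the CheckpointingException above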
@torch.no_grad()
@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
......@@ -427,8 +457,6 @@ def exchange_loaded_tensors_broadcast(
all_loaded_tensors = dict(loaded_tensors)
start = time()
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
......@@ -475,17 +503,13 @@ def exchange_loaded_tensors_broadcast(
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')
return all_loaded_tensors
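Note: a hedged sketch of the broadcast-based exchange summarized above, assuming torch.distributed.init_process_group() has already been called and CUDA is available; buffer shapes, dtypes and the rank bookkeeping are simplified compared to the real implementation.

import torch
import torch.distributed as dist

def broadcast_exchange_sketch(loaded, main_rank_for_shard, shard_shapes, group=None):
    # For every shard, the rank that loaded it broadcasts the tensor and all
    # other ranks receive it into a freshly allocated buffer.
    my_rank = dist.get_rank(group=group)
    for shard_id, src_rank in main_rank_for_shard.items():
        if my_rank == src_rank:
            tensor = loaded[shard_id].cuda()
        else:
            tensor = torch.empty(shard_shapes[shard_id], device='cuda')
        # NOTE: src must be a global rank; the real code maps group ranks to
        # global ranks and also restores the original (CPU) device afterwards.
        dist.broadcast(tensor, src=src_rank, group=group)
        loaded[shard_id] = tensor
    return loaded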
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
......@@ -508,6 +532,7 @@ def exchange_by_distribution(
previously loaded tensors (from `loaded_tensors` input)
"""
assert shard_distribution is not None, 'Expecting distribution to perform exchange'
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
......
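Note: an illustrative dispatch table equivalent to the if/elif chain above; the stub functions stand in for the real exchange implementations defined earlier in this module.

def _gather_object_stub(*args, **kwargs): ...
def _gather_rounds_stub(*args, **kwargs): ...
def _broadcast_stub(*args, **kwargs): ...

_EXCHANGE_FNS = {
    'gather_object': _gather_object_stub,
    'gather_rounds': _gather_rounds_stub,
    'broadcast': _broadcast_stub,
}

def pick_exchange_fn(exchange_algo: str = 'broadcast'):
    try:
        return _EXCHANGE_FNS[exchange_algo]
    except KeyError:
        raise NotImplementedError(f'Unknown exchange algorithm: {exchange_algo}')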
......@@ -119,7 +119,8 @@ class ShardedTensor(ShardedBase):
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape doesnt match expected {self.data.shape} for {self}'
f'Data shape {real_data.shape} does not match'
f' expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
......@@ -135,7 +136,13 @@ class ShardedTensor(ShardedBase):
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
# NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
# For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
# GPU0 receives p0 and a portion of p1, while GPU1 receives the
# remaining portion of p1 and p2.
# As a result, there is no parameter shard of p2 on GPU0, and
# the shape of p2 on GPU0 is zero.
if sh != 0 and off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
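Note: a tiny numeric illustration of the guard above, including the zero-size local shard case from custom FSDP that motivates the added `sh != 0` condition.

def check_offsets(global_offset, local_shape, prepend_axis_num=0):
    # Mirrors the check above: each global offset must be divisible by the
    # corresponding local shape, except when the local shard is empty (size 0).
    for off, sh in zip(global_offset[prepend_axis_num:], local_shape):
        if sh != 0 and off % sh != 0:
            raise ValueError(f'Global offset ({off}) must be divisible by local shape ({sh}).')

check_offsets((0, 8), (4, 4))   # passes: 0 % 4 == 0 and 8 % 4 == 0
check_offsets((0, 6), (0, 3))   # passes: the zero-size shard on axis 0 is skipped (FSDP case)
# check_offsets((0, 6), (4, 4)) would raise: 6 is not divisible by 4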
......@@ -515,10 +522,6 @@ class LocalNonpersistentObject:
return self.obj
# TODO: Delete once NeMo fixes typo.
LocalNonpersitentObject = LocalNonpersistentObject
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
......
......@@ -25,7 +25,7 @@ from .mapping import (
StateDict,
apply_factory_merges,
)
from .state_dict_transformation import load_preprocess, save_preprocess
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
......@@ -104,8 +104,6 @@ def load(
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict, including a tensor-aware implementation."""
""" Utilities for transforming state_dict."""
import logging
from time import time
from typing import Any, Callable, Optional
from typing import Callable, Union
import torch
from .dict_utils import dict_list_map_inplace, extract_matching_values, merge, nested_values
from .exchange_utils import determine_main_replica_uniform_distribution, exchange_by_distribution
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
CommonStateDict,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
from .utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
extract_nonpersistent,
extract_sharded_base,
)
from .utils import extract_nonpersistent, extract_sharded_base
from .validation import determine_global_metadata, validate_sharding_integrity
logger = logging.getLogger(__name__)
def save_preprocess(
sharded_state_dict: ShardedStateDict,
......@@ -54,6 +40,7 @@ def save_preprocess(
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
sharded_part = filter_out_empty_flatten_tensor(sharded_part)
if validate_access_integrity:
preprocessed_common_state_dict = common_state_dict
if preprocess_common_before_consistancy_check:
......@@ -84,6 +71,7 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
......@@ -100,171 +88,25 @@ def load_preprocess(sharded_state_dict: ShardedStateDict):
return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories
def prepare_state_dict_for_save(
sharded_state_dict: ShardedStateDict,
async_prepare: bool = False,
algo: str = 'atomic',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
to_cpu: bool = True,
):
"""Creates a tensor-aware state dictionary that can be saved using the Local Checkpoint Manager.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary.
async_prepare (bool): If True, enables asynchronous preparation.
algo (str): The algorithm used to create the tensor-aware state dictionary.
validate_access_integrity (bool): If True, validates sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for exchanges to avoid duplications.
to_cpu (bool): If True, moves all tensors from device to CPU.
Returns:
ShardedStateDict: The tensor-aware state dictionary.
def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]):
"""
_start = time()
if async_prepare:
raise NotImplementedError('Async state_dict preparation is not yet implemented')
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
sharded_part, common_state_dict = save_preprocess(sharded_state_dict, validate_access_integrity)
sharded_tensors = []
sharded_objects = []
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
sharded_tensors.append(sh_base)
else:
assert isinstance(sh_base, ShardedObject)
sharded_objects.append(sh_base)
if fully_parallel:
shard_to_saving_rank, _, shard_to_metadata = determine_main_replica_uniform_distribution(
sharded_part, parallelization_group, True
)
raw_tensors, raw_objects = {}, {}
for ten in sharded_tensors:
shard_id = _sharded_tensor_shard_id(ten)
if not fully_parallel or shard_to_saving_rank[shard_id] == torch.distributed.get_rank():
# TODO cover creating copies on host in CheckpointManager.save()
if to_cpu:
raw_tensors[shard_id] = ten.data.to("cpu", non_blocking=True)
else:
raw_tensors[shard_id] = ten.data
ten.data = None
for obj in sharded_objects:
raw_objects[_sharded_object_id(obj)] = obj.data
obj.data = None
logger.debug(f'prepare_state_dict_for_save took {time() - _start}')
state_dict_for_save = {
'raw_tensors': raw_tensors,
'raw_objects': raw_objects,
'common': common_state_dict,
'sharded_state_dict': sharded_part,
}
if fully_parallel:
state_dict_for_save['shard_to_rank'] = shard_to_saving_rank
state_dict_for_save['shard_to_metadata'] = shard_to_metadata
return state_dict_for_save
def recreate_state_dict_after_load(
sharded_state_dict: ShardedStateDict,
loaded_state_dict: ShardedStateDict,
algo: str = 'atomic',
exchange_algo: str = 'broadcast',
validate_access_integrity: bool = True,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
):
"""Creates a final sharded state dictionary from a tensor-aware state dictionary.
Filter out ShardedTensors with an empty flattened_range.
These tensors can cause the PyTorch check in
`TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
Args:
sharded_state_dict (ShardedStateDict):
The initial sharded state dictionary generated from the model.
loaded_state_dict (ShardedStateDict):
Tensor-aware state dictionary used to fill in missing data in the sharded state.
algo (str): The algorithm used to reconstruct the state dictionary
from the tensor-aware state dictionary.
exchange_algo (str): The algorithm used for tensor exchanges during retrieval.
validate_access_integrity (bool): If True, performs validation of sharding integrity.
parallelization_group (torch.distributed.ProcessGroup):
The process group used for efficient exchanges during retrieval.
Returns:
ShardedStateDict: The finalized sharded state dictionary.
sharded_state_dict: state dict possibly containing ShardedTensor objects
"""
if algo != 'atomic' and algo != 'fully_parallel':
raise NotImplementedError(
'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
)
fully_parallel = algo == 'fully_parallel'
# __adding__ common part
recreated_state_dict, _ = extract_matching_values(loaded_state_dict["common"], lambda x: True)
if not sharded_state_dict:
return recreated_state_dict
# TODO validate loaded_state_dict["sharded_state_dict"] and sharded_state_dict are compatible
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
# Filter out ShardedTensors with empty flatten_range.
# These tensors can cause the PyTorch check in
# `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
# This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases.
sharded_state_dict, _ = extract_matching_values(
sharded_state_dict,
lambda v: not (
isinstance(v, ShardedTensor)
and v.flattened_range
and v.flattened_range.start == v.flattened_range.stop
),
)
# __adding__ nonpersistent part
merge(recreated_state_dict, nonpersistent_state_dict)
sharded_part, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(determine_global_metadata(sharded_part)[1])
# load sharded tensors and sharded objects to sharded_part
loaded_tensors = loaded_state_dict['raw_tensors']
# TODO cover restoring the original device (H2D) in CheckpointManager.load()
for k, v in loaded_tensors.items():
loaded_tensors[k] = v.cuda() # H2D
if fully_parallel:
distribution = (
loaded_state_dict['shard_to_rank'],
None,
loaded_state_dict['shard_to_metadata'],
)
unloaded_shards = {}
for sh_base in nested_values(sharded_part):
if isinstance(sh_base, ShardedTensor):
shard_id = _sharded_tensor_shard_id(sh_base)
if shard_id not in loaded_tensors:
unloaded_shards[shard_id] = sh_base
loaded_tensors = exchange_by_distribution(
loaded_tensors, unloaded_shards, distribution, parallelization_group, exchange_algo
)
loaded_objects = loaded_state_dict['raw_objects']
def load_sharded_base(x: Any):
if isinstance(x, ShardedTensor):
shard_id = _sharded_tensor_shard_id(x)
if shard_id not in loaded_tensors:
raise Exception(
'The current local checkpoint implementation assumes '
'consistent tensor sharding during load and save operations. '
f'However, the expected shard {x} (ID: {shard_id}) '
f'was not found in the checkpoint. (IDs: {loaded_tensors.keys()})'
)
x = loaded_tensors[shard_id]
if isinstance(x, ShardedObject):
object_id = _sharded_object_id(x)
assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
x = loaded_objects[object_id]
return x
dict_list_map_inplace(load_sharded_base, sharded_part)
sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
# __adding__ sharded_part
merge(recreated_state_dict, sharded_part)
return recreated_state_dict
return sharded_state_dict
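Note: constructing a real ShardedTensor needs more setup than fits here, so the sketch below uses a hypothetical stand-in class to show just the predicate that the extract_matching_values filter above applies: entries whose flattened_range exists but spans zero elements are dropped.

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeShardedTensor:
    # Stand-in for mapping.ShardedTensor, keeping only the field the filter reads.
    flattened_range: Optional[slice]

def is_empty_flattened(v) -> bool:
    # Same condition as the lambda passed to extract_matching_values above,
    # written for the stand-in class instead of the real ShardedTensor.
    return (
        isinstance(v, FakeShardedTensor)
        and v.flattened_range is not None
        and v.flattened_range.start == v.flattened_range.stop
    )

print(is_empty_flattened(FakeShardedTensor(slice(3, 3))))   # True  -> filtered out
print(is_empty_flattened(FakeShardedTensor(slice(0, 16))))  # False -> kept
print(is_empty_flattened(FakeShardedTensor(None)))          # False -> kept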