Commit e532679c authored by oahzxl's avatar oahzxl
Browse files

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

parents c1492e50 7d5640b9
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
from .utils import merge_same_dim_mesh_list
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
......@@ -23,7 +23,7 @@ class _DimSpec:
This class is used internally in ShardingSpec.
Argument:
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
Otherwise, the element in shard_list means the data will be sharded in that dimension.
'''
......@@ -62,7 +62,7 @@ class _DimSpec:
def build_difference_2d_dict(self):
'''
Build a difference mapping for 2D device mesh case. It will be used to
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
......@@ -159,9 +159,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
class ShardingSpec:
'''
Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
to, the entire shape of the tensor before sharded, and the sharding sequence looks like
[R, R, S0, S1].
Argument:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of tensor before sharded.
......@@ -176,12 +176,19 @@ class ShardingSpec:
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
if isinstance(entire_shape, (list, tuple)):
entire_shape = torch.Size(entire_shape)
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None:
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
dim_partition_dict=self.dim_partition_dict)
self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
self.convert_shard_sequence_to_dict()
self._sanity_check()
......@@ -260,10 +267,10 @@ class ShardingSpec:
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
Output:
25
Argument:
other(ShardingSpec): The ShardingSpec to compared with.
......
from dataclasses import dataclass
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
@dataclass
class ColoTensorSpec:
""" ColoTensorSpec
A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
......
import torch
from typing import Dict, Iterator, List, Tuple, Union
from typing import Iterator, Tuple, Union
import torch
import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor
......@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
Therefore, all gather operation just remove the last element in shard list,
e.g.:
e.g.:
all-gather(S01) -> S0
Argument:
......@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
and simulate the influence of the DimSpec.
We BANNED all representations which shard_list in decreasing order,
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element describes which logical axis will be sharded in that dimension.
e.g.:
e.g.:
all-to-all(S0, S1) -> [S01, R]
all-to-all(S0, R) -> [R, S0]
Otherwise, we extend the front shard_list to behind.
e.g.:
e.g.:
all-to-all(R, S1) -> [S1, R]
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element describes which logical axis will be sharded in that dimension.
......@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
and simulate the influence of the DimSpec.
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
In addition, We BANNED all representations which shard_list in decreasing order,
In addition, We BANNED all representations which shard_list in decreasing order,
such as S10, so shard(S0) -> S10 is NOT allowed.
Therefore, for the R dimension, we could just append any legal sharding dim on it.
e.g.:
......@@ -89,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
return shard_list_list
def mix_gather_simulator(f_target_pair, b_target_pair):
    '''
    Assume index of f and b target pairs are 'f' and 'b'
    S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
    S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
    S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
    RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
    S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
    RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
    '''
    f_index, f_shard_list = f_target_pair
    b_index, b_shard_list = b_target_pair
    if f_shard_list and b_shard_list:
        # both dims sharded on one mesh axis each: order by lexicographic list comparison
        lead = int(b_shard_list > f_shard_list)
        return [b_index, f_index], [lead, lead ^ 1]
    if f_shard_list:
        # front dim carries both mesh axes; leading flag encodes their order (S01 vs S10)
        lead = int(f_shard_list[0] < f_shard_list[1])
        return [
            f_index,
        ], [lead, lead]
    if b_shard_list:
        lead = int(b_shard_list[0] < b_shard_list[1])
        return [
            b_index,
        ], [lead, lead]
# The function is credited to PyTorch Team
def named_params_with_colotensor(
module: nn.Module,
......@@ -164,3 +190,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
# Now we can set the attribute appropriately.
setattr(module, param_name, st)
def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    '''
    This method is used to convert the negative dim value to positive.

    The dict is updated in place and also returned.

    Argument:
        dim_size(int): rank of the tensor the keys index into.
        dim_partition_dict(Dict[int, List[int]]): maps tensor dim -> device-mesh axes it is sharded on.
    '''
    negative_dims = [dim for dim in dim_partition_dict if dim < 0]
    for dim in negative_dims:
        # Bug fix: the original reused a single stale `mesh_list` from a previous
        # loop, so every converted key got the same (possibly wrong) mesh list.
        # Pop-and-reinsert keeps each dim paired with its own mesh list.
        dim_partition_dict[dim_size + dim] = dim_partition_dict.pop(dim)
    return dim_partition_dict
def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
    '''
    This method is used to merge the different key value which points to same physical position.

    For example:
        dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
        In this method, above dim_partition_dict will be converted to {1: [0, 1]}

    The input dict is not modified; a new dict is returned.
    '''
    converted_dim_partition_dict = {}
    for dim, mesh_list in dim_partition_dict.items():
        if dim < 0:
            dim = dim_size + dim
        if dim not in converted_dim_partition_dict:
            # Bug fix: copy instead of aliasing the caller's list, otherwise the
            # `extend` below would mutate the caller's dim_partition_dict values.
            converted_dim_partition_dict[dim] = list(mesh_list)
        else:
            converted_dim_partition_dict[dim].extend(mesh_list)
    return converted_dim_partition_dict
......@@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.testing import assert_close
def assert_equal(a: Tensor, b: Tensor):
......@@ -12,12 +13,8 @@ def assert_not_equal(a: Tensor, b: Tensor):
assert not torch.all(a == b), f'expected a and b to be not equal but they are, {a} vs {b}'
def assert_close(a: Tensor, b: Tensor, rtol: float = 1e-5, atol: float = 1e-8):
assert torch.allclose(a, b, rtol=rtol, atol=atol), f'expected a and b to be close but they are not, {a} vs {b}'
def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
    """Assert that ``a`` and ``b`` are element-wise close under loose tolerances.

    Delegates to ``torch.testing.assert_close``, whose ``rtol``/``atol`` are
    keyword-only — the leftover positional call was removed (it raised TypeError).
    """
    assert_close(a, b, rtol=rtol, atol=atol)
def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
......@@ -30,4 +27,4 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert torch.all(a == b), f'expected tensors on rank {i} and {i+1} to be equal but they are not, {a} vs {b}'
assert torch.all(a == b), f'expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}'
import random
import numpy as np
import torch
def seed_all(seed, cuda_deterministic=False):
    """Seed every RNG source used in training for reproducibility.

    Seeds python's ``random``, numpy and torch (CPU and, when available, all
    CUDA devices). ``cuda_deterministic=True`` trades speed for reproducible
    cuDNN kernels; otherwise cuDNN autotuning is enabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    deterministic = bool(cuda_deterministic)
    # deterministic mode is slower but more reproducible
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.trainer.hooks import BaseHook
from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
@HOOKS.register_module
class MemTraceHook(BaseHook):
    """Save memory stats and pass it to states

    This hook is used to record memory usage info, and pass to trainer.states
    You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode]
    """

    def __init__(
        self,
        priority: int = 0,
    ) -> None:
        super().__init__(priority=priority)
        self._memory_monitor = AsyncMemoryMonitor()

    def after_hook_is_attached(self, trainer):
        # Initialize the data
        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
        trainer.states['metrics']['test'] = self._memory_monitor.state_dict

    def before_train_iter(self, trainer):
        self._memory_monitor.start()
        return super().before_train_iter(trainer)

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        self._memory_monitor.finish()
        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
        return super().after_train_iter(trainer, output, label, loss)

    def before_test_iter(self, trainer):
        self._memory_monitor.start()
        # Bug fix: forward to the matching parent hook; previously this called
        # super().before_test(trainer), skipping the parent's before_test_iter.
        return super().before_test_iter(trainer)

    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        self._memory_monitor.finish()
        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
        return super().after_test_iter(trainer, output, label, loss)
from .io import load, merge, redist, save
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type
from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter
_backends: Dict[str, Type['CheckpointIOBackend']] = {}
def register(name: str):
assert name not in _backends, f'"{name}" is registered'
def wrapper(cls):
_backends[name] = cls
return cls
return wrapper
def get_backend(name: str) -> 'CheckpointIOBackend':
assert name in _backends, f'Unsupported backend "{name}"'
return _backends[name]()
class CheckpointIOBackend(ABC):
    """Abstract interface for checkpoint storage backends.

    A backend hands out reader/writer objects for a checkpoint location and
    tracks temporary directories created during load-time preprocessing.
    """

    def __init__(self) -> None:
        super().__init__()
        # Paths of temp dirs handed out by get_temp(); removed by clean_temp().
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a writer for the checkpoint at ``base_name`` for this rank."""
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a reader for the checkpoint at ``base_name``."""
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        """Create and track a temporary checkpoint location under ``base_name``."""
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        """Remove every temporary location created by ``get_temp``."""
        pass
@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):
    """Checkpoint backend that reads/writes shards on the local filesystem."""

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        """Return a disk writer for this rank's shard of the checkpoint."""
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        """Return a disk reader for the checkpoint at ``base_name``."""
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        """Create a temp directory under ``base_name`` and track it for cleanup."""
        new_temp_dir = tempfile.mkdtemp(dir=base_name)
        self.temps.append(new_temp_dir)
        return new_temp_dir

    def clean_temp(self) -> None:
        """Delete every temp directory created via ``get_temp``."""
        for tracked_dir in self.temps:
            shutil.rmtree(tracked_dir)
import re
# File names used inside a checkpoint directory.
GLOBAL_META_FILE_NAME = 'global_meta.bin'    # index of per-rank meta files
MODEL_CKPT_FILE_NAME = 'model.bin'    # model state-dict shard
OPTIM_CKPT_FILE_NAME = 'optim.bin'    # optimizer state-dict shard
META_CKPT_FILE_NAME = 'meta.bin'    # per-rank meta info
OTHER_CKPT_FILE_NAME = 'other.bin'    # user-provided extras
# Matches any known checkpoint file stem (note: unanchored, so it matches substrings).
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
from torch import Tensor
from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)
class CheckpointConvertor(ABC):
    """Abstract consumer of checkpoint shards.

    Implementations accumulate shards fed via ``append`` and finalize/validate
    in ``complete``.
    """

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        """Consume one batch of shards, keyed by producing rank."""
        pass

    @abstractmethod
    def complete(self) -> None:
        """Finalize conversion after all shards were appended."""
        pass
class ModelCheckpointConvertor(CheckpointConvertor):
    """Buffers model-state shards per parameter until every producing rank has
    contributed, then hands the complete set to ``convert_tensors``.

    ``param_count`` maps parameter key -> number of ranks expected to hold a
    piece of it.
    """

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        # param key -> {rank -> tensor piece}; entries are dropped once converted
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Process all pieces of parameter ``key`` once they are available."""
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        # Stash every incoming tensor under its parameter key and producing rank.
        for rank, state_dict in shard_dict.items():
            for k, tensor in state_dict.items():
                self.buffer[k][rank] = tensor
        converted_keys = set()
        for k, rank_dict in self.buffer.items():
            # A parameter is ready once all expected ranks have contributed.
            if len(rank_dict) == self.param_count[k]:
                tensors = []
                dist_metas = []
                for rank, tensor in rank_dict.items():
                    tensors.append(tensor)
                    # dist_meta_list[rank] is None for global (non-distributed) shards
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][k])
                self.convert_tensors(k, tensors, dist_metas)
                converted_keys.add(k)
        # Free converted entries; deletion is deferred to avoid mutating while iterating.
        for k in converted_keys:
            del self.buffer[k]

    def complete(self) -> None:
        # Every buffered parameter must have been converted by now.
        assert len(self.buffer) == 0
class ModelCheckpointMerger(ModelCheckpointConvertor):
    """Merges distributed model shards into global tensors and feeds them to a
    size-bounded sharder, saving each produced shard via ``save_fn``."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        """Merge all pieces of ``key`` into one tensor and append it to the sharder."""
        assert len(tensors) == len(dist_metas)
        merged = merge_param(tensors, dist_metas)
        # append() returns a full shard once max_shard_size is exceeded, else None
        run_if_not_none(self.save_fn, self.sharder.append(key, merged))

    def complete(self) -> None:
        """Flush the final partial shard after the buffer check of the base class."""
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())
class ModelCheckpointRedistor(ModelCheckpointConvertor):
    """Re-distributes model shards to a new parallel layout.

    Each target process gets its own sharder and ``save_fn``; ``redist_meta``
    describes the target layout.
    """

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        nprocs = len(save_fns)
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
        # param key -> tp rank -> dp rank -> list of target process ranks
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        # First build the global tensor, then split it per the target layout.
        if len(dist_metas) == 0:
            # already global
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        # unmerge_param yields, per target tp rank, the per-dp-rank pieces.
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, t in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    shard = self.sharders[rank].append(key, t)
                    # a non-None shard is full and is flushed immediately
                    run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        # Flush each target process's final partial shard.
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
class OptimizerCheckpointConvertor(CheckpointConvertor):
    """Buffers optimizer-state shards per state index until every producing
    rank has contributed, then hands the complete set to ``convert_states``.

    ``param_to_os`` maps parameter key -> optimizer state index; ``paired_os``
    maps state index -> {state key -> bool} telling which entries are
    tensor-parallel and must be merged.
    """

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        # state index -> {rank -> state dict piece}
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        # NOTE(review): param_to_os is typed Optional but dereferenced here —
        # callers apparently always pass a dict; confirm before passing None.
        self.os_to_param = {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        """One-time initialization from the first seen ``param_groups``."""
        pass

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        """Process all pieces of optimizer state ``idx`` once available."""
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            # Ready once as many ranks contributed as hold the owning parameter.
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        # Deferred deletion to avoid mutating the buffer while iterating it.
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        # Every buffered state must have been converted by now.
        assert len(self.buffer) == 0
class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):
    """Merges distributed optimizer states into global ones and saves them
    through a size-bounded sharder via ``save_fn``."""

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        # created lazily in setup() because the sharder needs param_groups
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            # paired_os flags which state entries are sharded and need merging;
            # the rest (e.g. scalars like 'step') are taken from the first rank.
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        # Flush the final partial shard.
        run_if_not_none(self.save_fn, self.sharder.complete())
class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):
    """Re-distributes optimizer states to a new parallel layout described by
    ``redist_meta``; each target process gets its own sharder and ``save_fn``."""

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        # created lazily in setup() because sharders need param_groups
        self.sharders: List[OptimizerCheckpointSharder] = []
        # param key -> tp rank -> dp rank -> list of target process ranks
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if len(self.sharders) == 0:
            nprocs = len(self.save_fns)
            for _ in range(nprocs):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        # Empty dist_metas means the source is already global — skip merging.
        need_merge: bool = True
        if len(dist_metas) == 0:
            need_merge = False
        else:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                # Sharded entry: rebuild the global tensor, then split it per
                # the target layout of the owning parameter.
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                # Replicated entry (e.g. 'step'): copy to every target process.
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        # Flush each target process's final partial shard.
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
import torch
from numpy import prod
from torch import Tensor
from typing import List, Optional, Tuple
from collections import defaultdict
from .meta import ParamDistMeta, ParamRedistMeta
def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Rebuild one (tp-local) parameter from its ZeRO-flattened dp pieces.

    All metas must share the same zero meta. If ZeRO was not used, the pieces
    are replicas and the first one is returned as-is; otherwise the pieces are
    ordered by dp rank, concatenated and reshaped to the original shape.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
    head = dist_metas[0]
    if not head.used_zero:
        # replicated across dp ranks; any copy will do
        return tensors[0]
    ordered = [piece for _, piece in sorted(zip(dist_metas, tensors), key=lambda pair: pair[0].dp_rank)]
    assert head.zero_numel == sum(t.numel() for t in ordered), 'Expect numel of all params is equal to zero_numel.'
    return torch.cat(ordered).reshape(head.zero_orig_shape)
def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Rebuild a global parameter from its tensor-parallel shards.

    All metas must share the same tp meta and all shards the same shape.
    If tp was not used, the shards are replicas and the first is returned.
    Otherwise shards are concatenated dim by dim (highest shard dim first).

    Bug fix: the original read the validation loop's variable ``dist_meta``
    after the loop, which is unbound (NameError) when only one meta is given;
    use ``dist_metas[0]`` instead.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
    head = dist_metas[0]
    if not head.used_tp:
        # tensors are replicate
        return tensors[0]
    total_parts = prod(head.tp_num_parts)
    assert head.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {head.tp_world_size}.'
    # concatenate along the highest shard dim first (reverse of how it was split)
    shard_info = sorted(zip(head.tp_shard_dims, head.tp_num_parts), key=lambda t: t[0], reverse=True)
    for dim, num_parts in shard_info:
        buffer = []
        for start in range(0, len(tensors), num_parts):
            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
        tensors = buffer
    assert len(tensors) == 1
    return tensors[0]
def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    """Assert that every meta agrees on dp/tp world sizes (non-empty input)."""
    assert len(dist_metas) > 0
    # check world size
    reference = dist_metas[0]
    for dist_meta in dist_metas[1:]:
        assert dist_meta.dp_world_size == reference.dp_world_size, 'Expect all dist meta have the same dp_world_size'
        assert dist_meta.tp_world_size == reference.tp_world_size, 'Expect all dist meta have the same tp_world_size'
def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    """Drop tensors whose dist meta equals an earlier one (replicas), keeping
    first occurrences and preserving order."""
    keep_indices = []
    seen_metas = []
    for idx, meta in enumerate(dist_metas):
        if meta not in seen_metas:
            seen_metas.append(meta)
            keep_indices.append(idx)
    return [tensors[i] for i in keep_indices], [dist_metas[i] for i in keep_indices]
def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    """Rebuild one global parameter from all of its distributed pieces.

    Pipeline: validate world sizes, drop duplicate replicas, un-flatten the
    ZeRO pieces within each tp rank, then gather across tp ranks.
    """
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    unflattened_tensors = []
    # group zero params by tp rank
    tensor_dict = defaultdict(list)
    dist_meta_dict = defaultdict(list)
    for t, dist_meta in zip(tensors, dist_metas):
        tensor_dict[dist_meta.tp_rank].append(t)
        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
    assert len(tensor_dict
              ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
    # NOTE(review): iteration below relies on dict insertion order matching
    # ascending tp_rank of the inputs — confirm callers supply ranks in order.
    for tp_rank in tensor_dict.keys():
        unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
    return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])
def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Split a global tensor into tensor-parallel shards per ``redist_meta``.

    Without tp the tensor itself is returned (world size must be 1). With tp,
    the tensor is chunked dim by dim (lowest shard dim first), each chunk made
    contiguous; the result has exactly ``tp_world_size`` shards.
    """
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    shards = [tensor]
    for dim, num_parts in shard_info:
        next_level = []
        for piece in shards:
            assert piece.size(dim) % num_parts == 0, \
                f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
            next_level.extend(chunk.contiguous() for chunk in piece.chunk(num_parts, dim))
        shards = next_level
    assert len(shards) == redist_meta.tp_world_size
    return shards
def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    """Flatten one (tp-local) tensor into per-dp-rank ZeRO pieces.

    Without ZeRO, every dp rank gets the tensor itself (replicated). With
    ZeRO, ranks before ``zero_start_dp_rank`` and any trailing ranks receive
    empty tensors; the rest get 1-D slices delimited by ``zero_offsets``.
    """
    if not redist_meta.used_zero:
        return [tensor] * redist_meta.dp_world_size

    def _empty() -> Tensor:
        return torch.empty(0, dtype=tensor.dtype, device=tensor.device)

    pieces: List[Optional[Tensor]] = [_empty() for _ in range(redist_meta.zero_start_dp_rank)]
    flat = tensor.view(-1)
    boundaries = redist_meta.zero_offsets + [tensor.numel()]
    pieces.extend(flat[start:end] for start, end in zip(boundaries[:-1], boundaries[1:]))
    # pad trailing dp ranks that hold no part of this parameter
    while len(pieces) < redist_meta.dp_world_size:
        pieces.append(_empty())
    assert len(pieces) == redist_meta.dp_world_size
    return pieces
def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    """Split a global parameter into per-tp-rank, per-dp-rank pieces.

    Returns a list indexed by tp rank; each entry holds the ZeRO-flattened
    pieces for every dp rank.
    """
    return [flatten_zero_param(tp_shard, redist_meta) for tp_shard in split_tp_param(tensor, redist_meta)]
import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer
from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict
def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    """Save a (possibly distributed) checkpoint of ``model`` and ``optimizer``.

    Each rank writes its own shards; extra ``kwargs`` are stored via
    ``save_others``. ``max_shard_size_gb`` of 0 means unlimited shard size.
    ``dist_meta`` is required when world size > 1 and ignored otherwise.
    """
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # global doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    # convert GB to bytes
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                 param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)
def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    """Merge a distributed checkpoint at ``path`` into a global one at
    ``output_path``.

    Returns True if a merged checkpoint was written; False when not on rank 0
    or when the checkpoint is already global.
    """
    io_backend = get_backend(backend)
    # only rank 0 performs the merge
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    # convert GB to bytes; 0 means unlimited shard size
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    # dist_meta is None: the output checkpoint is global
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True
def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    """Re-distribute the checkpoint at ``path`` to the layout described by
    ``redist_meta``/``dist_metas``, writing one shard set per target process.

    Returns True if new checkpoints were written; False when not on rank 0 or
    when the stored layout already matches the requested one.
    """
    io_backend = get_backend(backend)
    # only rank 0 performs the redistribution
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    # redistribute only if the stored layout differs from the requested one
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False
    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    # convert GB to bytes; 0 means unlimited shard size
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                    param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
    # write each target process's meta with its own dist_meta
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True
def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
for shard_dict in shard_generator:
convertor.append(shard_dict, dist_meta_list)
convertor.complete()
def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    """Load a checkpoint into ``model`` (and ``optimizer``).

    Rank 0 first converts the stored checkpoint to the current layout (merge
    to global, or redist to the layout in ``redist_meta``/``dist_metas``) in a
    temp location, then every rank reads its own shard. Returns the extras
    stored via ``save_others``.
    """
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        # wrote is False when the stored layout already matched; read in place then
        if wrote:
            read_path = temp_path
    if not is_global:
        # let all ranks agree on the (possibly temp) path chosen by rank 0
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            # optimizer.load_state_dict(shard)
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        # make sure every rank finished reading before rank 0 deletes the temp dir
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
from dataclasses import dataclass
from typing import List, Optional, Set, Dict
@dataclass
class ParamDistMeta:
    """Distribution metadata of one parameter in a saved checkpoint.

    Records the data/tensor-parallel coordinates of the rank that saved the
    shard, plus optional tensor-parallel sharding and ZeRO flattening info.
    """
    # parallel info: dp/tp rank and world size of the saving process
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int
    # tp info: which dims were sharded and into how many parts (None = no TP)
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info: flattened element count and original shape (None = no ZeRO)
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """True when the parameter was sharded by tensor parallelism."""
        return not (self.tp_shard_dims is None or self.tp_num_parts is None)

    @property
    def used_zero(self) -> bool:
        """True when the parameter was flattened/partitioned by ZeRO."""
        return not (self.zero_numel is None or self.zero_orig_shape is None)

    @property
    def parallel_meta(self) -> tuple:
        """(dp_rank, dp_world_size, tp_rank, tp_world_size)."""
        return (self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size)

    @property
    def tp_meta(self) -> tuple:
        """(tp_shard_dims, tp_num_parts)."""
        return (self.tp_shard_dims, self.tp_num_parts)

    @property
    def zero_meta(self) -> tuple:
        """(zero_numel, zero_orig_shape)."""
        return (self.zero_numel, self.zero_orig_shape)

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        """Alternate constructor from a plain dict of field values."""
        return ParamDistMeta(**d)
@dataclass
class ParamRedistMeta:
    """Target layout of one parameter when redistributing a checkpoint."""
    # parallel info: target dp/tp world sizes
    dp_world_size: int
    tp_world_size: int
    # tp info: target sharded dims and part counts (None = no TP in the target)
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info: starting dp rank and flat offsets (None = no ZeRO in the target)
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        """True when the target layout shards this parameter with tensor parallelism."""
        return not (self.tp_shard_dims is None or self.tp_num_parts is None)

    @property
    def used_zero(self) -> bool:
        """True when the target layout partitions this parameter with ZeRO."""
        return not (self.zero_start_dp_rank is None or self.zero_offsets is None)
@dataclass
class RankRedistMeta:
    """Target parallel coordinates (dp/tp/pp ranks) of one process after redistribution."""
    dp_rank: int
    tp_rank: int
    pp_rank: int
@dataclass
class PipelineRedistMeta:
    """Set of parameter names owned by one pipeline stage after redistribution."""
    params: Set[str]
@dataclass
class RedistMeta:
    """Full redistribution plan: per-rank coordinates, pipeline stages, per-param layouts."""
    # presumably keyed by parameter name, then by flat rank index — verify against redist callers
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    pipeline_meta: List[PipelineRedistMeta]
    param_meta: Dict[str, ParamRedistMeta]
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple
import torch
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list
class CheckpointReader(ABC):
    """Abstract reader over the shard and meta files of a saved checkpoint."""
    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        # per-rank meta dicts, populated by concrete subclasses
        self.meta_list = []
    @abstractmethod
    def read(self, name: str) -> dict:
        """Deserialize one checkpoint file identified by *name*."""
        pass
    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Return (dist_meta_list, param_count, param_to_os, paired_os) aggregated over ranks."""
        pass
    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model state-dict shards saved by *rank*."""
        pass
    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: shard} dicts, one model shard per rank at a time."""
        pass
    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield the optimizer state-dict shards saved by *rank*."""
        pass
    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: shard} dicts, one optimizer shard per rank at a time."""
        pass
    @abstractmethod
    def load_others(self) -> dict:
        """Return the extra (non model/optimizer) objects saved with the checkpoint."""
        pass
class DiskCheckpointReader(CheckpointReader):
    """Checkpoint reader backed by ``torch.load``-able files under one directory."""
    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        # the global meta file lists one meta file name per saved rank
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)
    def read(self, name: str) -> dict:
        """torch.load the file *name* from the checkpoint directory."""
        return torch.load(os.path.join(self.base_name, name))
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        """Aggregate per-rank metas into (dist_meta_list, param_count, param_to_os, paired_os)."""
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os',
                                                                            None), meta.get('paired_os', None))
                      for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count: how many ranks saved each parameter name
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os: these must be identical across ranks
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]
    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        # lazily read each checkpoint file listed for *rank* under *shard_type*
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        """Yield the model shards saved by *rank*."""
        return self._load_shard('model', rank)
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: shard}, advancing all ranks in lock-step until every rank is exhausted."""
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        """Yield optimizer shards of *rank*, re-attaching 'param_groups' (stored only in the first shard)."""
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        """Yield {rank: shard} for optimizer shards, caching each rank's 'param_groups'."""
        indices = [0] * len(self.meta_list)
        # param_groups[i] caches rank i's param_groups from its first shard.
        # NOTE(review): the param_groups[i] indexing assumes every rank stores at
        # least one optimizer shard — confirm this holds for all writers.
        param_groups = []
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups.append(shards[i]['param_groups'])
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break
    def load_others(self) -> dict:
        """Return the extra-objects file saved alongside model/optimizer shards."""
        return self.read(OTHER_CKPT_FILE_NAME)
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple
from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from .meta import ParamDistMeta
def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    """Apply *fn* to *arg* unless *arg* is None; return None for a None arg."""
    return None if arg is None else fn(arg)
def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    """Map each named model parameter to its global index in the optimizer state.

    Indices follow ``Optimizer.state_dict()`` packing: groups are visited in
    order and each slot gets the next running index; a parameter appearing in
    several groups keeps the first index it was assigned.
    """
    # ensure all params in optimizer are in model state dict
    model_param_ids = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in model_param_ids
    index_of: Dict[int, int] = {}
    next_index = 0
    for group in optimizer.param_groups:
        for p in group['params']:
            # keep the first index when a param is repeated
            index_of.setdefault(id(p), next_index)
            next_index += 1
    return {name: index_of[id(p)] for name, p in model.named_parameters()}
def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    """Total byte size of all tensor values in one optimizer state entry."""
    return sum(v.numel() * v.element_size() for v in state.values() if isinstance(v, Tensor))
class ModelCheckpointSharder:
    """Accumulate model tensors into shards of at most ``max_shard_size`` bytes.

    A non-positive ``max_shard_size`` disables splitting: everything ends up in
    the single shard returned by :meth:`complete`.
    """

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        """Add one tensor; return the previous shard if the buffer was flushed, else None."""
        flushed = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            flushed, self.buffer, self.buffer_size = self.buffer, {}, 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return flushed

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        """Append every item of *state_dict*; return all shards flushed along the way."""
        flushed_shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            if shard is not None:
                flushed_shards.append(shard)
        return flushed_shards

    def complete(self) -> Optional[dict]:
        """Return the trailing (possibly only) shard, or None when nothing is buffered."""
        return self.buffer or None
class OptimizerCheckpointSharder:
    """Accumulate optimizer state entries into shards of at most ``max_shard_size`` bytes.

    Only the first shard carries 'param_groups'; shards created after a flush
    hold just 'state' (the reader re-attaches param_groups). A non-positive
    ``max_shard_size`` disables splitting.
    """

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False    # NOTE(review): never read in this file — confirm before removing

    def append(self, key: int, state: dict) -> Optional[dict]:
        """Add one param's state; return the previous shard if the buffer was flushed, else None."""
        flushed = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            flushed = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        # count only tensor values, matching compute_optimizer_state_size
        self.buffer_size += sum(v.numel() * v.element_size() for v in state.values() if isinstance(v, Tensor))
        return flushed

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        """Append every entry of state_dict['state']; return all shards flushed along the way."""
        flushed_shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            if shard is not None:
                flushed_shards.append(shard)
        return flushed_shards

    def complete(self) -> Optional[dict]:
        """Return the trailing shard, or None when no state is buffered."""
        return self.buffer if self.buffer['state'] else None
def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    """Split model (and optionally optimizer) state dicts into size-bounded shards.

    Returns (model_shards, optimizer_shards); the latter is empty when no
    optimizer state dict is given. With optimizer state, *param_to_os* must map
    every optimizer state index back to a model parameter name.
    """
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        # every optimizer state index must correspond to a known model parameter
        os_to_param = {os_idx: name for name, os_idx in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    last_model_shard = model_sharder.complete()
    if last_model_shard is not None:
        model_shards.append(last_model_shard)
    if optimizer_state_dict is None:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    last_optimizer_shard = optimizer_sharder.complete()
    if last_optimizer_shard is not None:
        optimizer_shards.append(last_optimizer_shard)
    return model_shards, optimizer_shards
def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    """Flag which optimizer state values share their paired parameter's shape.

    Returns {os_index: {state_key: bool}} where True marks a tensor whose shape
    equals the corresponding model parameter's shape (e.g. momentum buffers).
    """
    os_to_param = {os_idx: name for name, os_idx in param_to_os.items()}
    paired = {}
    for os_idx, state in optimizer_state_dict['state'].items():
        param = model_state_dict[os_to_param[os_idx]]
        paired[os_idx] = {
            state_key: isinstance(value, Tensor) and value.shape == param.shape
            for state_key, value in state.items()
        }
    return paired
def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    """Build size-bounded model/optimizer shards plus this rank's meta dict.

    When *dist_meta* is None the checkpoint is treated as global. Otherwise,
    with *eliminate_replica*, data-parallel replicated parameters (dp_rank != 0
    and not ZeRO-sharded) are dropped from this rank's checkpoint.
    """
    is_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        # derive the param -> optimizer-state mapping when the caller did not supply one
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
    if not is_global and eliminate_replica:
        # filter dp replicated params: keep only what this dp rank owns
        model_state_dict = {
            name: tensor
            for name, tensor in model_state_dict.items()
            if dist_meta[name].used_zero or dist_meta[name].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[name]: optimizer_state_dict['state'][param_to_os[name]]
                for name in model_state_dict.keys()
                if dist_meta[name].used_zero or dist_meta[name].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if not model_state_dict:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta
def is_duplicated_list(list_: List[Any]) -> bool:
    """Return True iff every element equals the first one (vacuously True when empty)."""
    return all(x == list_[0] for x in list_) if list_ else True
def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    """Copy *src_state* values into keys that already exist in *dest_state*.

    Tensor destinations are copied element-wise in place (preserving the
    destination tensor object); other existing keys are overwritten. Keys
    absent from *dest_state* are intentionally left out — this merges into an
    existing state layout rather than extending it.
    """
    for key, value in src_state.items():
        if key not in dest_state:
            continue
        existing = dest_state[key]
        if isinstance(existing, Tensor):
            existing.copy_(value)
        else:
            dest_state[key] = value
def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    """Load *state_dict* into *optimizer* in place, copying into existing state tensors.

    Unlike ``Optimizer.load_state_dict`` this pairs saved param ids with current
    params positionally (zipping saved and current param groups) and merges each
    state entry via ``copy_optimizer_state``. With ``strict=True``, missing or
    unexpected state keys raise a RuntimeError.
    """
    # param_groups layouts must match exactly so the positional zip below is valid
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    # saved id -> current Parameter, paired by position across all groups.
    # NOTE(review): annotation says Dict[str, ...] but torch's state_dict format
    # uses int indices for saved param ids — confirm before tightening the type.
    idx_to_p: Dict[str, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
                                                           )), chain.from_iterable((g['params'] for g in groups)))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            # merge into the current optimizer state for the matched parameter
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                     "\n\t".join(error_msgs)))
from abc import ABC, abstractmethod
from typing import Optional
from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
import torch
import os
class CheckpointWriter(ABC):
    """Abstract writer producing the shard and meta files of a checkpoint."""
    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0
    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        """Serialize *state_dict* under the file name *name*."""
        pass
    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        """Persist one model state-dict shard."""
        pass
    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Persist one optimizer state-dict shard."""
        pass
    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        """Persist this rank's meta info (called after all model/optimizer shards)."""
        pass
    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        """Persist extra user objects saved alongside the checkpoint."""
        pass
class DiskCheckpointWriter(CheckpointWriter):
    """Checkpoint writer that ``torch.save``s shard and meta files into a directory."""
    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        # once the meta file is written, no further shards may be added
        self.is_meta_saved: bool = False
        self._save_global_meta()
    def write(self, name: str, state_dict: dict) -> None:
        """torch.save *state_dict* into the directory, refusing to clobber unless overwrite=True."""
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)
    def _save_global_meta(self) -> None:
        # rank 0 records the list of per-rank meta file names
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)
    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        # decorate the base file name with rank and shard suffixes as needed
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name
    def save_model(self, model_checkpoint: dict) -> None:
        """Write one model shard and record its file name for the meta file."""
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        """Write one optimizer shard and record its file name for the meta file."""
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)
    def save_meta(self, meta_checkpoint: dict) -> None:
        """Attach the recorded shard file names to *meta_checkpoint* and write it."""
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True
    def save_others(self, kwargs: dict) -> None:
        """Rank 0 writes the extra-objects file; other ranks no-op."""
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import functools
import os
import random
import socket
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Union, Dict, Optional
import functools
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch._six import inf
from torch.nn.parameter import Parameter
try:
import colossal_C
except:
pass
from contextlib import contextmanager
import torch.distributed as dist
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES)
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.tensor import ColoParameter, ProcessGroup
from .multi_tensor_apply import multi_tensor_applier
from colossalai.tensor import ColoParameter, ProcessGroup
from collections import defaultdict
try:
from colossalai._C import fused_optim
except:
fused_optim = None
def print_rank_0(msg: str, logger=None):
......@@ -128,11 +127,18 @@ def is_model_parallel_parameter(p):
def _calc_l2_norm(grads):
# we should not
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
colossal_C.multi_tensor_l2norm,
fused_optim.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False # no per-parameter norm
......@@ -269,7 +275,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
cpu_grads.append(p.grad.detach())
if len(cuda_grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads], clip_coef)
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
clip_coef)
for g in cpu_grads:
g.mul_(clip_coef)
......@@ -395,7 +402,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if enable_cuda_kernels:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
else:
for p in params:
p.grad.detach().mul_(clip_coeff)
......
from .utils import InsertPostInitMethodToModuleSubClasses
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import torch
from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding
from torch import nn
from typing import Iterator, Tuple, Union
from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from .utils import InsertPostInitMethodToModuleSubClasses
# find named_params includes replica
......@@ -23,6 +26,39 @@ def _named_params_with_replica(
yield name, val
def _convert_to_coloparam(param: torch.nn.Parameter,
                          device: torch.device,
                          dtype=torch.float,
                          default_pg: Optional[ProcessGroup] = None,
                          default_dist_spec: Optional[Any] = None) -> ColoParameter:
    """Convert a plain ``torch.nn.Parameter`` into a ``ColoParameter``.

    Meta-device params are wrapped as-is; real params are moved to *device* and
    cast to *dtype* first. If given, *default_pg* and *default_dist_spec* are
    applied to the new parameter. Already-converted params are returned unchanged.
    """
    if isinstance(param, ColoParameter):
        return param
    # detaching tensor is necessary for optimizers.
    requires_grad = param.requires_grad
    # param is the global tensor.
    if param.device.type == "meta":
        colo_param = ColoParameter(param, requires_grad=requires_grad)
    else:
        colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
    # if default_shard_plan exists, shard the param during initialization.
    # This can reduce the model size after initialization.
    # NOTE() embedding usually can not be correctly sharded. So I use except to handle
    # the param that can not be sharded by the default plan
    if default_pg is not None:
        colo_param.set_process_group(default_pg)
    if default_dist_spec is not None:
        try:
            colo_param.set_dist_spec(default_dist_spec)
        except Exception:
            # Best-effort: params that cannot take the default dist spec stay
            # replicated. `except Exception` (not a bare `except`) so that
            # KeyboardInterrupt/SystemExit still propagate.
            pass
    return colo_param
def ColoModulize(module):
"""
Replacing the parameters() and named_parameters() with our customized ones
......@@ -34,20 +70,24 @@ def ColoModulize(module):
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
def __init__(self,
lazy_memory_allocate: bool = False,
device: torch.device = torch.device('cpu'),
dtype: torch.dtype = torch.float):
dtype: torch.dtype = torch.float,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec=None):
"""
Args:
lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
dtype (torch.dtype): the dtype of parameters initialized. Defults to torch.float.
default_pg (ProcessGroup): the default process group for all initialized parameters.
default_dist_spec: the default distributed specifications.
"""
super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate
self._device = device
self._dtype = dtype
self._register_colo_modules()
self._default_pg = default_pg
self._default_dist_spec = default_dist_spec
def _register_colo_modules(self):
register_colo_module(torch.nn.Linear, ColoLinear())
......@@ -61,10 +101,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times?
"""
if hasattr(module, '_colo_visited'):
return
name_list = []
for name, param in _named_params_with_replica(module):
if isinstance(param, ColoTensor):
......@@ -87,17 +123,74 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
if param in replaced_tensors:
colo_param = replaced_tensors[param]
else:
save_torch_payload = True if not self._lazy_memory_allocate else False
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# TODO(jiaruifang) we initialize a Default PG memory
colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
requires_grad=requires_grad)
# add mapping record
colo_param = _convert_to_coloparam(param, self._device, self._dtype, self._default_pg,
self._default_dist_spec)
replaced_tensors[param] = colo_param
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
meta_param_flag = 0
meta_buffer_flag = 0
for param in module.parameters():
if param.device.type=="meta":
meta_param_flag = 1
if meta_param_flag == 1 and param.device.type!="meta":
raise ValueError("Meta parameters and valued parameters can not be in the same model")
for buffer in module.buffers():
if buffer.device.type=="meta":
meta_buffer_flag = 1
if meta_buffer_flag == 1 and buffer.device.type!="meta":
raise ValueError("Meta buffers and valued buffers can not be in the same model")
if meta_param_flag==1 and meta_buffer_flag==1:
pass
elif meta_buffer_flag==0 and meta_param_flag==1:
for name, buf in module.named_buffers():
module._buffers[name] = module._buffers[name].to(device=self._device)
elif meta_param_flag==0 and meta_buffer_flag==1:
for name, param in module.named_parameters():
module._parameters[name] = module._parameters[name].to(device=self._device)
else:
module.to(self._device)
def post_process_colo_init_ctx(model: torch.nn.Module,
                               device: torch.device = torch.device('cpu'),
                               dtype: torch.dtype = torch.float,
                               default_pg: Optional[ProcessGroup] = None,
                               default_dist_spec=None):
    """post_process_colo_init_ctx

    This function is called after `ColoInitContext`. It converts any parameter
    of *model* that is still a plain ``torch.nn.Parameter`` into a
    ``ColoParameter`` in place.

    Args:
        model (torch.nn.Module): the model
        device (torch.device, optional): device type of the model params. Defaults to torch.device('cpu').
        dtype (torch.dtype, optional): dtype of the model params. Defaults to torch.float.
        default_pg (Optional[ProcessGroup], optional): default process group. Defaults to None. Indicates a DP-only process group.
        default_dist_spec (Any, optional): default dist spec of params. Defaults to None.

    Raises:
        RuntimeError: if any parameter is still not a ColoTensor after conversion.
    """
    # collect params that ColoInitContext did not convert
    torch_params = []
    for n, p in model.named_parameters():
        if not isinstance(p, ColoParameter):
            torch_params.append((n, p))
    for (n, param) in torch_params:
        # walk down the module tree to the submodule owning this parameter
        name_list = n.split('.')
        module = model
        for i in range(len(name_list) - 1):
            module = module._modules[name_list[i]]
        delattr(module, name_list[-1])
        setattr(module, name_list[-1], _convert_to_coloparam(param, device, dtype, default_pg, default_dist_spec))
    del torch_params
    for n, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            raise RuntimeError
#!/usr/bin/env python
# coding: utf-8
import inspect
import types
from typing import Callable, List
import torch
import torch.nn as nn
from colossalai.tensor import ColoParameter, ColoTensor
import types
import inspect
from typing import List, Callable
from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import substitute_init_recursively
class LazyInitContext():
"""
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
initialization functions for lazy initialization
Note:
This API is only experimental and subject to future changes.
This API is only experimental and subject to future changes.
Usage:
with LazyInitContext() as ctx:
......@@ -30,19 +31,20 @@ class LazyInitContext():
# initialize weights
ctx.lazy_init_parameters(model)
# make sure the weight is not a meta tensor
# make sure the weight is not a meta tensor
# and initialized correctly
assert not model.weight.is_meta and torch.all(model.weight == 0)
Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
"""
tensor_set_value_func = ['zero_', 'fill_']
def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
# TODO: hijack the torch constructor functions as well
self._to_meta = to_meta
self._intercepted_nn_init_func_cache = {}
......@@ -212,18 +214,19 @@ class LazyInitContext():
materialized_tensor = torch.empty_like(tensor, device=device)
# if this tensor is a meta tensor, it must have an init function
assert tensor in self._intercepted_nn_init_func_cache
tensor = materialized_tensor
else:
materialized_tensor = tensor
# apply init function
if tensor in self._intercepted_nn_init_func_cache:
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
init_func(tensor, *args, **kwargs)
init_func(materialized_tensor, *args, **kwargs)
# convert it to ColoTensor or ColoParameter
if is_param:
tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
else:
tensor = ColoTensor.from_torch_tensor(tensor)
tensor = ColoTensor.from_torch_tensor(materialized_tensor)
# override the original tensor
with torch.no_grad():
......
......@@ -14,7 +14,6 @@ class MultiTensorApply(object):
def __init__(self, chunk_size):
try:
import colossal_C
MultiTensorApply.available = True
self.chunk_size = chunk_size
except ImportError as err:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment