Unverified commit 27061426, authored by Hongxin Liu and committed by GitHub

[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
parent 285fe7ba
 import gc
 import logging
 import os
-import warnings
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple
 import torch
 import torch.nn as nn
-from torch import Tensor
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
@@ -16,7 +14,6 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralC
 from colossalai.checkpoint_io.utils import (
     get_model_base_filenames,
     get_optimizer_base_filenames,
-    get_shard_filename,
     load_shard_state_dict,
     save_state_dict,
     save_state_dict_shards,
@@ -24,8 +21,7 @@ from colossalai.checkpoint_io.utils import (
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.zero.gemini import ZeroOptimizer
+from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
 from .dp_plugin_base import DPPluginBase
@@ -132,11 +128,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         As there is communication when getting state dict, this must be called on all processes.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -183,11 +175,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = optimizer.unwrap()
-        assert isinstance(optimizer, ZeroOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer)
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@@ -220,47 +208,6 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         super().save_lr_scheduler(lr_scheduler, checkpoint)
-class GeminiModel(ModelWrapper):
-    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
-        super().__init__(module)
-        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)
-    def unwrap(self):
-        # as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
-        return self.module
-class GeminiOptimizer(OptimizerWrapper):
-    def __init__(self,
-                 module: GeminiDDP,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('Gemini does not support clip_grad_by_value')
 class GeminiPlugin(DPPluginBase):
     """
     Plugin for Gemini.
@@ -277,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
     >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
     Args:
-        device (torch.device): device to place the model.
-        placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        chunk_config_dict (dict, optional): chunk configuration dictionary.
+        chunk_init_device (torch.device, optional): device to initialize the chunk.
+        placement_policy (str, optional): "static" and "auto". Defaults to "static".
+        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+            If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
+        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
+        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+            If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
+            When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
+            Defaults to 0.0.
+        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
         precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@@ -310,8 +269,14 @@ class GeminiPlugin(DPPluginBase):
     def __init__(
         self,
-        device: Optional[torch.device] = None,
-        placement_policy: str = "cpu",
+        chunk_config_dict: Optional[dict] = None,
+        chunk_init_device: Optional[torch.device] = None,
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
         precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
@@ -335,8 +300,14 @@ class GeminiPlugin(DPPluginBase):
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
-            device=(device or get_current_device()),
+            chunk_config_dict=chunk_config_dict,
+            chunk_init_device=(chunk_init_device or get_current_device()),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
+            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
@@ -393,12 +364,15 @@ class GeminiPlugin(DPPluginBase):
         # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
         # wrap the model with Gemini
-        model = GeminiModel(model, self.gemini_config, self.verbose)
+        model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                        self.verbose)
+            optimizer = GeminiOptimizer(optimizer,
+                                        model.unwrap(),
+                                        **self.zero_optim_config,
+                                        **self.optim_kwargs,
+                                        verbose=self.verbose)
         return model, optimizer, criterion, dataloader, lr_scheduler
...
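With the plugin reworked as above, placement is controlled by fractions instead of a device string. A minimal usage sketch under the new interface, assuming a distributed launch (e.g. torchrun) so the default process group exists; the boost call is kept exactly as in the plugin docstring, with model, optimizer and dataloader as placeholders:

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    colossalai.launch_from_torch(config={})

    # shard_param_frac=1.0 with offload_optim_frac=0.0 roughly matches the old "cuda" placement;
    # raising offload_optim_frac / offload_param_frac toward 1.0 approaches the old "cpu" placement.
    plugin = GeminiPlugin(placement_policy="static",
                          shard_param_frac=1.0,
                          offload_optim_frac=0.5,
                          precision="fp16")
    booster = Booster(plugin=plugin)
    # model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)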
@@ -3,9 +3,15 @@ from typing import Optional
 import torch
 from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.const import TensorType
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+from .colo_tensor import _convert_output
+WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}
+def is_no_hook_op(func) -> bool:
+    return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS
 def filter_colo_parameters(*args, **kwargs):
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
     """
-    def __new__(cls,
-                data: Optional[torch.Tensor] = None,
-                requires_grad: bool = True,
-                spec: ColoTensorSpec = None) -> 'ColoParameter':
+    def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
-    def __init__(self,
-                 data: Optional[torch.Tensor] = None,
-                 requires_grad: bool = True,
-                 spec: ColoTensorSpec = None) -> None:
-        ColoTensor.__init__(self, data, spec)
-        self._type = TensorType.MODEL
-        # a list contains modules sharing this ColoParameter with others.
-        self._shared_param_modules = []
-    @property
-    def shared_param_modules(self):
-        return self._shared_param_modules
-    @staticmethod
-    def from_torch_tensor(tensor: torch.Tensor,
-                          requires_grad: bool = True,
-                          spec: ColoTensorSpec = None) -> 'ColoParameter':
-        tensor = tensor.as_subclass(ColoParameter)
-        tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
-        return tensor
-    def __repr__(self):
-        return super(ColoParameter, self).__repr__()
     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
-        if ColoParamOpHookManager.has_hook():
-            if not func.__name__.startswith('__'):
-                if kwargs is None:
-                    kwargs = {}
-                params = filter_colo_parameters(*args, **kwargs)
-                if len(params) > 0:
-                    with torch._C.DisableTorchFunction():
-                        new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
-                    args, kwargs = replace_args(args, kwargs, new_args)
-                    ret = super().__torch_function__(func, types, args, kwargs)
-                    with torch._C.DisableTorchFunction():
-                        ret = ColoParamOpHookManager.post_op(params, ret)
-                    return ret
+        if kwargs is None:
+            kwargs = {}
+        if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
+            params = filter_colo_parameters(*args, **kwargs)
+            if len(params) > 0:
+                with torch._C.DisableTorchFunction():
+                    new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
+                args, kwargs = replace_args(args, kwargs, new_args)
+                ret = super().__torch_function__(func, types, args, kwargs)
+                with torch._C.DisableTorchFunction():
+                    ret = ColoParamOpHookManager.post_op(params, ret)
+                return _convert_output(ret, func)
         return super().__torch_function__(func, types, args, kwargs)
     def __deepcopy__(self, memo):
@@ -96,9 +74,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data,
-                                   self.requires_grad,
-                                   spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
+            tensor = ColoParameter(data, self.requires_grad)
             memo[id(self)] = tensor
             return tensor
...
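The new is_no_hook_op whitelist skips double-underscore ops except __getitem__, so indexing a parameter still flows through pre_op/post_op. A rough sketch of a hook that observes this, assuming ColoParamOpHook exposes pre_forward/post_forward/pre_backward/post_backward and ColoParamOpHookManager.use_hooks as in the existing ColossalAI API (not shown in this diff):

    import torch
    from colossalai.tensor.colo_parameter import ColoParameter
    from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager

    class CountingHook(ColoParamOpHook):
        # counts how many ops touched a parameter before the forward computation
        def __init__(self):
            super().__init__()
            self.num_pre_forward = 0
        def pre_forward(self, params):
            self.num_pre_forward += 1
        def post_forward(self, params):
            pass
        def pre_backward(self, params):
            pass
        def post_backward(self, params):
            pass

    p = ColoParameter(torch.randn(4, 4))
    hook = CountingHook()
    with ColoParamOpHookManager.use_hooks(hook):
        _ = p.sum()    # ordinary op: routed through pre_op/post_op
        _ = p[0]       # __getitem__ is whitelisted, so it is hooked as well
    assert hook.num_pre_forward == 2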
-import operator
-from copy import copy
-from functools import lru_cache, reduce
-from typing import Callable, Optional, Set
+from functools import lru_cache
+from typing import Callable, Set
 import torch
-from colossalai.tensor.dist_spec_mgr import DistSpecManager
-from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
-from colossalai.tensor.process_group import ProcessGroup
-from colossalai.tensor.tensor_spec import ColoTensorSpec
-from .const import TensorType
-from .op_wrapper import _COLOSSAL_OPS
+INPALCE_MAPPING = {
+    torch.Tensor.add_: torch.Tensor.add,
+    torch.Tensor.sub_: torch.Tensor.sub,
+    torch.Tensor.mul_: torch.Tensor.mul,
+    torch.Tensor.div_: torch.Tensor.div
+}
 @lru_cache(None)
@@ -25,61 +22,37 @@ def _get_my_nowrap_functions() -> Set[Callable]:
     }
-def _convert_output(output, colo_spec: ColoTensorSpec):
-    if type(output) == torch.Tensor:
-        return ColoTensor.from_torch_tensor(output, colo_spec)
-    elif isinstance(output, (list, tuple)):
-        return type(output)(_convert_output(o, colo_spec) for o in output)
-    else:
-        return output
-def _get_spec_from_args(args, kwargs) -> ColoTensorSpec:
-    for elem in args:
-        if isinstance(elem, ColoTensor):
-            pg = elem.get_process_group()
-            dp = elem.dist_spec
-            return ColoTensorSpec(pg, dp)
-        elif isinstance(elem, (list, tuple)):
-            spec = _get_spec_from_args(elem, {})
-            if spec is not None:
-                return spec
-    for k, v in kwargs.items():
-        if isinstance(v, ColoTensor):
-            pg = v.get_process_group()
-            dp = v.dist_spec
-            return ColoTensorSpec(pg, dp)
-    return None
+def _convert(output):
+    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
+        output.__class__ = ColoTensor
+    elif isinstance(output, (list, tuple)):
+        output = type(output)(_convert(o) for o in output)
+    return output
+def _convert_output(output, func):
+    if func in _get_my_nowrap_functions():
+        return output
+    return _convert(output)
 class ColoTensor(torch.Tensor):
     """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
-    The Colotensor can be initialized with a PyTorch tensor in the following ways.
-    >>> pg = ProcessGroup()
-    >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec()))
-    >>> # The tensor passed in is a tensor after sharding but not a global tensor.
-    >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
-    >>>                        dims=[0],
-    >>>                        num_partitions=[world_size])
-    >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
-    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    It is only used to trigger the torch function hook.
     Args:
         data (torch.Tensor): a torch tensor used as the payload the colotensor.
-        spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
     """
     torch_major = int(torch.__version__.split('.')[0])
     torch_minor = int(torch.__version__.split('.')[1])
-    def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
+    def __new__(cls, data: torch.Tensor) -> 'ColoTensor':
         """
         The signature of the __new__ has to be consistent with the torch.Tensor.
         Args:
            data (torch.Tensor): a torch tensor used as the payload the colotensor.
-            spec (TensorSpec, optional): the tensor spec of initialization.
         Returns:
            ColoTensor: a ColoTensor wrappers the data.
@@ -88,86 +61,6 @@ class ColoTensor(torch.Tensor):
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)
-    def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
-        # If not set spec, use a DP process group and replicate dist spec
-        if spec is None:
-            self.has_initialized = False
-            self.dist_spec = ReplicaSpec()
-            self.compute_spec = None
-            self.process_group = ProcessGroup()
-        else:
-            self.has_initialized = True
-            self.dist_spec = spec.dist_attr
-            self.compute_spec = spec.compute_attr
-            if spec.pg is None:
-                self.process_group = ProcessGroup()
-            else:
-                self.process_group = spec.pg
-        self._type = TensorType.NONMODEL
-    def has_compute_spec(self) -> bool:
-        return self.compute_spec is not None
-    def is_model_data(self) -> bool:
-        return self._type == TensorType.MODEL
-    def get_process_group(self) -> 'ProcessGroup':
-        return self.process_group
-    def set_process_group(self, pg: ProcessGroup):
-        """set_process_group
-        change the pg of the ColoTensor. Note that the valid use cases is limited.
-        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.
-        Args:
-            pg (ProcessGroup): target pg
-        """
-        assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        # if the new pg is the same as the old pg, just returns
-        if self.process_group == pg:
-            return
-        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
-        assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"
-        self.process_group = pg
-    def get_tp_world_size(self) -> int:
-        return self.process_group.tp_world_size()
-    def get_dp_world_size(self) -> int:
-        """get_dp_world_size
-        get the dp world size of the tensor.
-        Returns:
-            int: dp world size
-        """
-        return self.process_group.dp_world_size()
-    def set_dist_spec(self, dist_spec: _DistSpec):
-        """set_dist_spec
-        set dist spec and change the payloads.
-        Args:
-            dist_spec (_DistSpec): target dist spec.
-        """
-        assert isinstance(dist_spec, _DistSpec)
-        assert self.process_group is not None
-        self._redistribute(dist_spec)
-    def set_tensor_spec(self, dist_spec, compute_spec):
-        if dist_spec is not None:
-            assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
-            self.set_dist_spec(dist_spec)
-        if compute_spec is not None:
-            self.compute_spec = compute_spec
-    def has_compute_pattern(self, compute_pattern):
-        return self.compute_spec.compute_pattern == compute_pattern
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -175,9 +68,6 @@ class ColoTensor(torch.Tensor):
         if not all(issubclass(cls, t) for t in types):
             return NotImplemented
-        global _COLOSSAL_OPS
-        if func in _COLOSSAL_OPS:
-            func = _COLOSSAL_OPS[func]
         if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
             # in order to trigger pre-op hook in the forward of checkpoint module
@@ -189,94 +79,16 @@ class ColoTensor(torch.Tensor):
                 tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                 return backward_tensor.backward(**tensor_kwargs)
+        # replace the in-place function
+        if func in INPALCE_MAPPING:
+            func = INPALCE_MAPPING[func]
+        # set the 'inplace' kwargs to False
+        if 'inplace' in kwargs:
+            kwargs['inplace'] = False
         with torch._C.DisableTorchFunction():
             ret = func(*args, **kwargs)
-        if func in _get_my_nowrap_functions():
-            return ret
-        else:
-            colo_spec = _get_spec_from_args(args, kwargs)
-            return _convert_output(ret, colo_spec)
+        return _convert_output(ret, func)
-    def __repr__(self):
-        output_list = [super(ColoTensor, self).__repr__()]
-        output_list.append(str(self.process_group))
-        output_list.append(str(self.dist_spec))
-        if self.compute_spec is not None:
-            output_list.append(str(self.compute_spec))
-        return "\n".join(output_list)
-    def _redistribute(self, dist_spec: _DistSpec) -> None:
-        """_redistribute
-        Note the function will not handle the logic of backward propagation!
-        It is used during model tensor initializations as an internal function.
-        Args:
-            dist_spec (_DistSpec): the target dist. spec.
-        """
-        assert self.grad_fn is None, "Current tensor has grad_fn and it can't get converted"
-        with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
-        self.dist_spec = dist_spec
-    def redistribute(self, dist_spec: _DistSpec, pg: Optional[ProcessGroup] = None) -> 'ColoTensor':
-        """redistribute
-        Redistribute the tensor among processes. The rule is like this:
-        1. If the pg is None, then redistribute the tensor payload among the TP process group. Keep the
-        DP process group not changed.
-        2. If the pg is not not None and not equal to the current process group.
-        First, convert the tensor as replicated among the TP process group.
-        Second, reset the process group to the new pg.
-        Third, convert the tensor (new replicated both among the tp process group) to the new dist_spec.
-        Args:
-            dist_spec (_DistSpec): the new dist spec.
-            pg (Optional[ProcessGroup], optional): the new process group . Defaults to None.
-        Returns:
-            ColoTensor: a redistributed colotensor
-        """
-        if pg is not None and pg != self.get_process_group():
-            # if the pg is not equal, convert the current tensor to replicated
-            handled = self.redistribute(ReplicaSpec())
-        else:
-            handled = self
-            pg = self.process_group
-        ret = DistSpecManager.handle_trans_spec(handled, handled.dist_spec, dist_spec, pg)
-        return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(pg=pg, dist_attr=dist_spec))
-    def to_replicate_(self):
-        """to_replicate_
-        an inline member function, converting dist spec of the tensor to REPLICATE
-        """
-        self._redistribute(dist_spec=ReplicaSpec())
-    def to_replicate(self) -> 'ColoTensor':
-        """to_replicate
-        converting dist spec of the tensor to ReplicaSpec()
-        """
-        return self.redistribute(ReplicaSpec())
-    @staticmethod
-    def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
-        """from_torch_tensor
-        A static method builds a `ColoTensor` from a PyTorch Tensor.
-        Args:
-            tensor (torch.Tensor): the pytorch tensor, which is a local tensor for this rank not a global tensor.
-            spec (Optional[ColoTensorSpec], optional): tensor spec. Defaults to None.
-        Returns:
-            ColoTensor: a ColoTensor
-        """
-        tensor = tensor.as_subclass(ColoTensor)
-        tensor.__init__(tensor, spec=spec)
-        return tensor
     def __deepcopy__(self, memo):
         if id(self) in memo:
@@ -284,60 +96,6 @@ class ColoTensor(torch.Tensor):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
+            tensor = ColoTensor(data)
             memo[id(self)] = tensor
             return tensor
-    # override builtin functions which must use tensor in replicate placement
-    def size_local(self, *args) -> torch.Size:
-        with torch._C.DisableTorchFunction():
-            return super().size(*args)
-    def size_global(self, *args) -> torch.Size:
-        """size_global
-        override the torch building size()
-        the shape passed in must be in a replicate placement.
-        Returns:
-            torch.Size: the global tensor shape
-        """
-        if self.is_replicate():
-            return self.size_local(*args)
-        spec = self.dist_spec
-        dims = spec.dims
-        num_partitions = spec.num_partitions
-        # import inspect
-        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
-        size_list = list(self.size_local())
-        for dim, num_partition in zip(dims, num_partitions):
-            size_list[dim] *= num_partition
-        if args == ():
-            return torch.Size(size_list)
-        else:
-            return size_list[args[0]]
-    def numel_global(self):
-        """Returns the number of elements in the tensor when it's replicated.
-        """
-        return reduce(operator.mul, self.size_global(), 1)
-    # Some API for dist spec check
-    def is_replicate(self):
-        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
-            or (len(self.dist_spec.num_partitions) == 1
-                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.process_group.tp_world_size() == 1)
-    def is_shard_1dcol(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
-    def is_shard_1drow(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD \
-            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
-    def is_sharded(self):
-        return self.dist_spec.placement == DistPlacementPattern.SHARD
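The INPALCE_MAPPING table and the new tail of __torch_function__ redirect in-place arithmetic to its out-of-place counterpart so the hook machinery can post-process and re-wrap the result. A standalone sketch of that pattern outside ColossalAI (class and dict names here are illustrative only):

    import torch

    # map in-place methods to their out-of-place counterparts
    INPLACE_TO_FUNCTIONAL = {
        torch.Tensor.add_: torch.Tensor.add,
        torch.Tensor.mul_: torch.Tensor.mul,
    }

    class Wrapped(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # re-dispatch in-place ops to their functional form, like INPALCE_MAPPING above
            func = INPLACE_TO_FUNCTIONAL.get(func, func)
            with torch._C.DisableTorchFunction():
                ret = func(*args, **kwargs)
            # re-wrap plain tensor outputs, mirroring _convert
            if isinstance(ret, torch.Tensor) and not isinstance(ret, Wrapped):
                ret.__class__ = Wrapped
            return ret

    t = torch.randn(3).as_subclass(Wrapped)
    out = t.add_(1.0)              # executed as torch.Tensor.add under the hood
    assert isinstance(out, Wrapped)
    # note: by design the original tensor is no longer modified in place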
@@ -3,9 +3,7 @@ from contextlib import contextmanager
 from typing import Any, List, Tuple
 import torch
-from colossalai.tensor.colo_tensor import ColoTensor
-from colossalai.tensor.tensor_spec import ColoTensorSpec
+from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten
 class ColoParamOpHook(ABC):
@@ -82,26 +80,18 @@ class ColoParamOpHookManager:
     @staticmethod
     def pre_op(params: List[torch.Tensor], *args: Any) -> list:
         ColoParamOpHookManager._trigger_pre_forward(params)
-        grad_args, rear_args = _get_grad_args(*args)
-        colo_info = _get_colo_tensors_info(*grad_args)
-        rets = PreFwdPostBwd.apply(params, *grad_args)
-        update_args = _update_colo_tensors(colo_info, *rets)
-        if rear_args is None:
-            return update_args
-        else:
-            arg_zero = (tuple(update_args),)
-            return arg_zero + rear_args
+        # auto grad function can only recognize torch.Tensor, thus we have to flatten the input
+        # if one of the input requires grad, all the output will be treated as requires grad
+        # and will have grad fn even the corresponding input does not require grad
+        # we have to extract tensors requiring grad into flat list and then merge them back
+        grad_args, other_args, grad_flags, spec = _flatten_grad_args(args)
+        new_grad_args = PreFwdPostBwd.apply(params, *grad_args)
+        return _merge_args(new_grad_args, other_args, grad_flags, spec)
     @staticmethod
     def post_op(params: List[torch.Tensor], arg: Any) -> Any:
         ColoParamOpHookManager._trigger_post_forward(params)
-        colo_info = _get_colo_tensors_info(arg)
-        ret = PostFwdPreBwd.apply(params, arg)
-        res = _update_colo_tensors(colo_info, ret)
-        if len(res) == 1:
-            return res[0]
-        else:
-            return res
+        return PostFwdPreBwd.apply(params, arg)
     @staticmethod
     def has_hook() -> bool:
@@ -141,57 +131,24 @@ def _is_grad_tensor(obj) -> bool:
     return False
-def _has_grad_tensor(obj) -> bool:
-    if isinstance(obj, tuple) or isinstance(obj, list):
-        for x in obj:
-            if _has_grad_tensor(x):
-                return True
-        return False
-    elif isinstance(obj, dict):
-        for x in obj.values():
-            if _has_grad_tensor(x):
-                return True
-        return False
-    else:
-        return _is_grad_tensor(obj)
-def _get_grad_args(*args):
-    # if there is no grad tensors, do nothing
-    if not _has_grad_tensor(args):
-        return args, None
-    # returns the identical args if there is a grad tensor
-    for obj in args:
-        if _is_grad_tensor(obj):
-            return args, None
-    # otherwise, the first argument should be a tuple of grad tensors
-    # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
-    arg_zero = args[0]
-    if not isinstance(arg_zero, tuple):
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    check_grad_flag = False
-    for obj in arg_zero:
-        check_grad_flag |= _is_grad_tensor(obj)
-    if not check_grad_flag:
-        raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.")
-    return arg_zero, args[1:]
-def _get_colo_tensors_info(*args) -> list:
-    info = []
-    for arg in args:
-        if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
-        else:
-            info.append(None)
-    return info
-def _update_colo_tensors(info, *args) -> list:
-    ret = []
-    for t_info, arg in zip(info, args):
-        if t_info is not None:
-            t_cls, spec = t_info
-            arg = t_cls.from_torch_tensor(arg, spec=spec)
-        ret.append(arg)
-    return ret
+def _flatten_grad_args(args) -> Tuple[list, list, List[bool], TreeSpec]:
+    flat_args, spec = tree_flatten(args)
+    grad_args = []
+    other_args = []
+    grad_flags = []
+    for arg in flat_args:
+        flag = _is_grad_tensor(arg)
+        grad_flags.append(flag)
+        if flag:
+            grad_args.append(arg)
+        else:
+            other_args.append(arg)
+    assert len(grad_args) > 0
+    return grad_args, other_args, grad_flags, spec
+def _merge_args(grad_args, other_args, grad_flags, spec):
+    grad_iter = iter(grad_args)
+    other_iter = iter(other_args)
+    flat_args = [next(grad_iter) if flag else next(other_iter) for flag in grad_flags]
+    return tree_unflatten(flat_args, spec)
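_flatten_grad_args and _merge_args lean on torch.utils._pytree so that only the tensors requiring grad are passed through the autograd function while everything else is spliced back into its original position. A small round-trip sketch of that idea (values are illustrative):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    args = (torch.randn(2, requires_grad=True), {"scale": 2.0}, [torch.zeros(1)])
    flat, spec = tree_flatten(args)

    # remember, per leaf, whether it carries grad
    flags = [isinstance(x, torch.Tensor) and x.requires_grad for x in flat]
    grad_leaves = [x for x, f in zip(flat, flags) if f]
    other_leaves = [x for x, f in zip(flat, flags) if not f]

    # grad_leaves would go through an autograd function (PreFwdPostBwd) here

    # merge back in the original order and rebuild the nested structure
    g, o = iter(grad_leaves), iter(other_leaves)
    merged = [next(g) if f else next(o) for f in flags]
    restored = tree_unflatten(merged, spec)
    assert isinstance(restored, tuple) and len(restored) == 3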
@@ -2,8 +2,7 @@ from .gemini import (
     ColoInitContext,
     GeminiAdamOptimizer,
     GeminiDDP,
-    ZeroDDP,
-    ZeroOptimizer,
+    GeminiOptimizer,
     get_static_torch_model,
     post_process_colo_init_ctx,
 )
@@ -11,6 +10,6 @@ from .low_level import LowLevelZeroOptimizer
 from .wrapper import zero_model_wrapper, zero_optim_wrapper
 __all__ = [
-    'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
+    'GeminiDDP', 'GeminiOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
     'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
 ]
 from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
 from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
-from .gemini_ddp import GeminiDDP, ZeroDDP
+from .gemini_ddp import GeminiDDP
 from .gemini_mgr import GeminiManager
-from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
+from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
 from .utils import get_static_torch_model
 __all__ = [
-    'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
-    'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
+    'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'GeminiDDP',
+    'get_static_torch_model', 'GeminiAdamOptimizer', 'GeminiOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
 ]
@@ -4,8 +4,8 @@ from typing import Dict, List, Optional
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.utils import get_current_device
@@ -55,7 +55,7 @@ class Chunk:
     def __init__(self,
                  chunk_size: int,
-                 process_group: ColoProcessGroup,
+                 process_group: ProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
                  cpu_shard_init: bool = False,
@@ -69,7 +69,7 @@ class Chunk:
         Args:
             chunk_size (int): the number of elements in the chunk
-            process_group (ColoProcessGroup): the process group of this chunk
+            process_group (ProcessGroup): the process group of this chunk
            dtype (torch.dtype): the data type of the chunk
            init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
                The default value is None, which is the current GPU
@@ -83,7 +83,7 @@ class Chunk:
         self.chunk_size = chunk_size
         self.utilized_size = 0
-        self.torch_pg = process_group.dp_process_group()
+        self.torch_pg = process_group
         self.pg_size = dist.get_world_size(self.torch_pg)
         self.pg_rank = dist.get_rank(self.torch_pg)
@@ -218,7 +218,7 @@ class Chunk:
             return False
         else:
             return self.tensor_state_cnter[TensorState.HOLD] + \
                 self.tensor_state_cnter[TensorState.HOLD_AFTER_BWD] == self.num_tensors
     @property
     def can_reduce(self):
...
@@ -2,8 +2,9 @@ from collections import deque
 from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
 import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
 from .chunk import Chunk, ChunkFullError, TensorState
@@ -27,16 +28,17 @@ class ChunkManager:
             self.dp_degree_chunk_size_dict[k] = v.pop('chunk_size')
             v['init_device'] = self.device
-        self.chunk_groups: Dict[str, Deque] = dict()
+        self.chunk_groups: Dict[str, Deque[Chunk]] = dict()
         self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = dict()
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}
     def register_tensor(self,
-                        tensor: ColoTensor,
+                        tensor: torch.Tensor,
                         group_type: str,
                         config_key: int,
+                        process_group: ProcessGroup,
                         cpu_offload: bool = False,
                         pin_memory: bool = False) -> None:
         """
@@ -51,7 +53,7 @@ class ChunkManager:
             pin_memory: whether the chunk is pinned in the cpu memory
         """
         assert tensor not in self.tensor_chunk_map
-        assert isinstance(tensor, ColoTensor), "Please feed ColoTensor to this ChunkManager"
+        assert isinstance(tensor, torch.Tensor), "Please feed Tensor to this ChunkManager"
         assert config_key in self.dp_degree_chunk_size_dict
         chunk_size = self.dp_degree_chunk_size_dict[config_key]
@@ -73,12 +75,12 @@ class ChunkManager:
             if tensor.numel() > chunk_size:
                 chunk_size = tensor.numel()
-            dp_size = tensor.get_dp_world_size()
+            dp_size = dist.get_world_size(process_group)
             chunk_size = chunk_size + (-chunk_size % dp_size)
             chunk = Chunk(
                 chunk_size=chunk_size,
-                process_group=tensor.process_group,
+                process_group=process_group,
                 dtype=tensor.dtype,
                 cpu_shard_init=cpu_offload,
                 pin_memory=pin_memory,
@@ -220,7 +222,7 @@ class ChunkManager:
             msg.append(f'[{i}] {chunk}\n')
         return ''.join(msg)
-    def __get_chunk_group(self, group_name: str) -> Deque:
+    def __get_chunk_group(self, group_name: str) -> Deque[Chunk]:
         """Register a chunk group.
         """
         if group_name not in self.chunk_groups:
...
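register_tensor pads the chunk size up to a multiple of the data-parallel world size (chunk_size + (-chunk_size % dp_size)) so every rank owns an equally sized shard. A quick check of that arithmetic with made-up numbers:

    def round_up_to_multiple(chunk_size: int, dp_size: int) -> int:
        # -chunk_size % dp_size is the smallest non-negative padding that reaches the next multiple
        return chunk_size + (-chunk_size % dp_size)

    assert round_up_to_multiple(1000, 8) == 1000   # already divisible
    assert round_up_to_multiple(1001, 8) == 1008   # padded by 7
    assert round_up_to_multiple(7, 8) == 8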
@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_ddp_ignored
@@ -59,7 +60,7 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc
-def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
+def _tensor_numel(local_param: ColoParameter) -> int:
     """_tensor_numel
     Get the number of elements of a tensor.
@@ -71,15 +72,12 @@ def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
     Returns:
         int: the number of elements.
     """
-    if strict_ddp_flag and type(local_param) is ColoParameter:
-        return local_param.numel_global()
-    else:
-        # if local_param is not ColoParameter, we assume it's replicated
-        return local_param.numel()
+    # TODO(ver217): support dtensor here
+    return local_param.numel()
 def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
-                                 strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
+                                 process_group: ProcessGroup) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree
     Classify the parameters by their dp degree
@@ -97,13 +95,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
         # assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if is_ddp_ignored(param):
             continue
+        param_key = dist.get_world_size(process_group)
-        if strict_ddp_flag or type(param) is not ColoParameter:
-            # if model is not initialized with ColoInitContext, we assume it's replicated
-            # TODO(ver217): integrate DTensor
-            param_key = dist.get_world_size()
-        else:
-            param_key = param.process_group.dp_world_size()
         if param_key not in params_dict:
             params_dict[param_key] = []
@@ -119,6 +111,7 @@ def search_chunk_configuration(
     min_chunk_size_m: float = 32,
     filter_exlarge_params: bool = True,
     strict_ddp_flag: bool = False,
+    process_group: Optional[ProcessGroup] = None,
     memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
     """search_chunk_configuration
@@ -149,7 +142,7 @@ def search_chunk_configuration(
     min_chunk_size = round(min_chunk_size_m * 1024**2)
     assert search_range >= 0
-    params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
+    params_dict = classify_params_by_dp_degree(param_order, process_group)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
     config_dict: Dict[int, Dict] = dict()
     total_param_size = 0
@@ -157,7 +150,7 @@ def search_chunk_configuration(
     size_dict: Dict[int, List[int]] = dict()
     for dp_degree in params_dict:
         params_list = params_dict[dp_degree]
-        size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+        size_list = [_tensor_numel(p) for p in params_list]
         group_acc_size = sum(size_list)
         total_param_size += group_acc_size
...
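search_chunk_configuration groups parameters by the world size of their process group and aligns the searched chunk size to all of those degrees via np.lcm.reduce. A tiny illustration of that alignment step with hypothetical degrees (not the exact search logic):

    import numpy as np

    dp_degrees = [2, 4, 6]                              # hypothetical per-group world sizes
    size_lcm = int(np.lcm.reduce(dp_degrees))           # 12
    candidate = 32 * 1024**2                            # a candidate chunk size in elements
    aligned = candidate + (-candidate % size_lcm)       # round up so every group shards evenly
    assert aligned % size_lcm == 0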
@@ -2,19 +2,20 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
+from typing import Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ProcessGroup
+from torch.distributed.distributed_c10d import _get_default_group
 from colossalai.checkpoint_io.utils import calculate_tensor_size
+from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
-from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
-from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.tensor import ReplicaSpec
-from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
+from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device, is_ddp_ignored
@@ -30,14 +31,13 @@ except ImportError:
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 __all__ = [
-    'ZeroDDP',
     'GeminiDDP',
 ]
-class ZeroDDP(ColoDDP):
-    """ZeRO DDP for ColoTensor.
-    Warning: Nested ZeroDDP is not supported now.
+class GeminiDDP(ModelWrapper):
+    """ZeRO DDP.
+    Warning: Nested GeminiDDP is not supported now.
     It is designed to be used with ChunkManager and GeminiManager.
     For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
@@ -54,20 +54,54 @@ class ZeroDDP(ColoDDP):
         mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """
-    def __init__(self,
-                 module: torch.nn.Module,
-                 gemini_manager: GeminiManager,
-                 pin_memory: bool = False,
-                 force_outputs_fp32: bool = False,
-                 strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True,
-                 mixed_precision: torch.dtype = torch.float16) -> None:
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        chunk_config_dict: Optional[dict] = None,
+        chunk_init_device: torch.device = torch.device('cpu'),
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,  # only for static placement
+        offload_optim_frac: float = 0.0,  # only for static placement
+        offload_param_frac: float = 0.0,  # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
+        search_range_m: int = 32,  # chunk search options
+        hidden_dim: Optional[int] = None,  # chunk search options
+        min_chunk_size_m: float = 32,  # chunk search options
+        pin_memory: bool = False,
+        force_outputs_fp32: bool = False,
+        strict_ddp_mode: bool = False,
+        scatter_after_inference: bool = True,
+        mixed_precision: torch.dtype = torch.float16,
+        process_group: Optional[ProcessGroup] = None,
+        memstats: Optional[MemStats] = None,  # genimi memory stats
+        verbose: bool = False) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
-        self.gemini_manager = gemini_manager
-        self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
+        if chunk_config_dict is not None:
+            self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
+        else:
+            # some ugly hotfix for the compatibility with Lightning
+            if search_range_m is None:
+                search_range_m = 32
+            self.chunk_manager = init_chunk_manager(model=module,
+                                                    init_device=chunk_init_device,
+                                                    hidden_dim=hidden_dim,
+                                                    search_range_m=search_range_m,
+                                                    min_chunk_size_m=min_chunk_size_m,
+                                                    strict_ddp_flag=strict_ddp_mode,
+                                                    process_group=process_group,
+                                                    verbose=verbose)
+        self.gemini_manager = GeminiManager(placement_policy,
+                                            self.chunk_manager,
+                                            memstats,
+                                            shard_param_frac=shard_param_frac,
+                                            offload_optim_frac=offload_optim_frac,
+                                            offload_param_frac=offload_param_frac,
+                                            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+                                            steady_cuda_cap_ratio=steady_cuda_cap_ratio)
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = GeminiZeROHook(gemini_manager)
-        self.fp32_params: List[ColoTensor] = list()
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -75,6 +109,7 @@ class ZeroDDP(ColoDDP):
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
         self.mixed_precision = mixed_precision
+        self.dp_process_group = process_group or _get_default_group()
         self._logger = get_dist_logger()
@@ -88,20 +123,67 @@ class ZeroDDP(ColoDDP):
         for p in module.parameters():
             param_order.append(p)
-        self._init_chunks(param_order=param_order,
-                          strict_ddp_mode=strict_ddp_mode,
-                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
-                          pin_memory=pin_memory)
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
-        super().__init__(module, process_group=ColoProcessGroup())
+        self._init_chunks(param_order=param_order,
+                          strict_ddp_mode=strict_ddp_mode,
+                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
+                          pin_memory=pin_memory)
+        super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
+        # register grad hook
+        for p in module.parameters():
+            if is_ddp_ignored(p):
+                continue
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle, p))
+    def parameters(self, recurse: bool = True):
+        return self.module.parameters(recurse)
+    def named_parameters(self, prefix: str = '', recurse: bool = True):
+        return self.module.named_parameters(prefix, recurse)
+    def named_buffers(self, prefix: str = '', recurse: bool = True):
+        return self.module.named_buffers(prefix, recurse)
+    def named_children(self):
+        return self.module.named_children()
+    def named_modules(self,
+                      memo: Optional[Set[torch.nn.Module]] = None,
+                      prefix: str = '',
+                      remove_duplicate: bool = True):
+        return self.module.named_modules(memo, prefix, remove_duplicate)
+    @staticmethod
+    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
+        """Sets parameters to be ignored by DDP.
+        This method must be called before initializing ColoDDP.
+        Example:
+            >>> params_to_ignore = []
+            >>> for p in module.parameters():
+            >>>    if should_ignore(p):
+            >>>        params_to_ignore.append(p)
+            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
+            >>> module = ColoDDP(module)
+        Args:
+            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
+        """
+        for p in params_to_ignore:
+            p._ddp_to_ignore = True
+    def unwrap(self):
+        # as save/load state dict is overwrited, only return self
+        return self
     def _get_non_persistent_buffers_set(self,
                                         module,
@@ -207,7 +289,7 @@ class ZeroDDP(ColoDDP):
                 error_params.append(self.param2name[param])
             error_str = "\n\t".join(error_params)
             raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
-                               "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+                               "The most possible reason is that the model is not compatible with GeminiDDP.\n",
                                f"{error_str}")
         self._setup_grads_ptr()
         self._logger.debug(
@@ -227,6 +309,7 @@ class ZeroDDP(ColoDDP):
         self._post_backward()
     def grad_handle(self, p, grad):
+        setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         with torch._C.DisableTorchFunction():
@@ -533,7 +616,7 @@ class ZeroDDP(ColoDDP):
             for chunk_32 in chunk_list:
                 chunk_16 = chunk_32.paired_chunk
                 assert chunk_16 is not None
-                chunk_16.optim_update()
+                chunk_16.payload.copy_(chunk_32.payload)
             for name, buf in persistent_buffers.items():
                 if buf is not None:
@@ -557,17 +640,11 @@ class ZeroDDP(ColoDDP):
                     unexpected_keys.append(key)
     def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
-        ddp_pg = ColoProcessGroup()
+        dp_world_size = dist.get_world_size(self.dp_process_group)
         for p in param_order.generate():
             self._preprocess_param(p)
             assert type(p) is ColoParameter
-            # gather sharded parameters in the strict ddp mode
-            if strict_ddp_mode:
-                if not p.is_replicate():
-                    p.set_dist_spec(ReplicaSpec())
-                p.set_process_group(pg=ddp_pg)
            # ignore the parameters with no gradient
            if not p.requires_grad:
                self.set_params_to_ignore([p])
@@ -578,38 +655,37 @@ class ZeroDDP(ColoDDP):
                 continue
             # create a fp32 parameter
-            fp32_data = p.data.float()
-            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+            fp32_p = p.data.float()
             # create a fp16 parameter
             p.data = p.data.to(self.mixed_precision)
             # register the fp16 parameter and fp32 parameter in the chunk manager
-            dp_world_size = p.process_group.dp_world_size()
             self.chunk_manager.register_tensor(tensor=p,
                                                group_type='fp16_param',
                                                config_key=dp_world_size,
+                                               process_group=self.dp_process_group,
                                                cpu_offload=cpu_offload,
                                                pin_memory=pin_memory)
             self.chunk_manager.register_tensor(tensor=fp32_p,
                                                group_type='fp32_param',
                                                config_key=dp_world_size,
+                                               process_group=self.dp_process_group,
                                                cpu_offload=cpu_offload,
                                                pin_memory=pin_memory)
             self.fp16_params.append(p)
             self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()
+        self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
+        # move master weights to corresponding device and setup paired chunks
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
             chunk_32.init_pair(chunk_16)
-            # keep gathered chunks are in CUDA
-            if chunk_16.keep_gathered:
-                self.grads_device[p] = get_current_device()
+            if chunk_32.device_type != self.grads_device[p].type:
+                self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
     def _cast_buffers(self):
         for buffer in self.module.buffers():
@@ -727,67 +803,3 @@ class _StateDictSharder:
         self.current_block[name] = tensor
         self.current_block_size += tensor_size
         return ret_block, ret_block_size
-class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
search_range_m: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
search_range_m = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision)
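For reference, a minimal hedged sketch of constructing this wrapper with the keyword arguments in the signature above. The toy module, the hyper-parameter values, and the assumption that this runs inside an already-launched distributed/CUDA context are illustrative, not part of the commit:

```python
import torch
from colossalai.zero.gemini.gemini_ddp import GeminiDDP

# hedged sketch: wrap a toy module with the signature shown above;
# values are placeholders, and a colossalai.launch context is assumed
model = torch.nn.Linear(1024, 1024)
model = GeminiDDP(model,
                  device=torch.cuda.current_device(),
                  placement_policy='auto',       # policy name accepted by the factory
                  pin_memory=True,
                  hidden_dim=1024,               # speeds up the chunk-size search
                  min_chunk_size_m=32,
                  mixed_precision=torch.float16)
```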
import functools import functools
from time import time from time import time
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
...@@ -26,7 +26,11 @@ class GeminiManager: ...@@ -26,7 +26,11 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration. memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
""" """
def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs) -> None:
assert placement_policy in PlacementPolicyFactory.get_policy_names() assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy self.policy_name = placement_policy
...@@ -37,7 +41,7 @@ class GeminiManager: ...@@ -37,7 +41,7 @@ class GeminiManager:
self._memstats = memstats self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1 self._compute_idx: int = -1
...@@ -133,10 +137,6 @@ class GeminiManager: ...@@ -133,10 +137,6 @@ class GeminiManager:
if self._warmup and self._placement_policy.need_mem_stats: if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks) self._compute_list.append(chunks)
@property
def default_device(self):
return self._placement_policy.get_default_device()
def sample_overall_data(self): def sample_overall_data(self):
if self._mem_stats_collector: if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data() self._mem_stats_collector.sample_overall_data()
...@@ -159,6 +159,6 @@ class GeminiManager: ...@@ -159,6 +159,6 @@ class GeminiManager:
def is_cuda_margin_mem_avail(self) -> bool: def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats return self._placement_policy.need_mem_stats
@staticmethod def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
def get_default_device(policy_name: str) -> torch.device: torch.device]) -> None:
return PlacementPolicyFactory.get_default_device(policy_name) self._placement_policy.setup_grads_device(params, grads_device_map)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import copy import copy
import math import math
import warnings import warnings
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -11,15 +11,16 @@ from torch.optim import Optimizer ...@@ -11,15 +11,16 @@ from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager from .chunk import Chunk, ChunkManager
from .gemini_ddp import ZeroDDP from .gemini_ddp import GeminiDDP
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] __all__ = ['GeminiOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
...@@ -27,7 +28,7 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} ...@@ -27,7 +28,7 @@ _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self, def __init__(self,
module: ZeroDDP, module: GeminiDDP,
initial_scale: float = 2**16, initial_scale: float = 2**16,
min_scale: float = 1, min_scale: float = 1,
growth_factor: float = 2, growth_factor: float = 2,
...@@ -46,11 +47,11 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): ...@@ -46,11 +47,11 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
self.module.overflow_counter = 0 self.module.overflow_counter = 0
class ZeroOptimizer(ColossalaiOptimizer): class GeminiOptimizer(OptimizerWrapper):
"""A wrapper for optimizer. ``ZeroDDP`` and ``ZeroOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3). """A wrapper for optimizer. ``GeminiDDP`` and ``GeminiOptimizer`` implement Zero Redundancy Optimizer (ZeRO state-3).
Note: Note:
You must use ``ZeroDDP`` with ``ZeroOptimizer``. You must use ``GeminiDDP`` with ``GeminiOptimizer``.
Note: Note:
Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`, Make sure you set ``placement_policy`` of ``GeminiManager`` to `"auto"`,
...@@ -58,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -58,7 +59,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
Args: Args:
optim (Optimizer): An Optimizer instance. optim (Optimizer): An Optimizer instance.
module (ZeroDDP): A ``ZeroDDP`` instance. module (GeminiDDP): A ``GeminiDDP`` instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer. which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
...@@ -70,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -70,15 +71,15 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000. growth_interval (float, optional): Growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2. hysteresis (float, optional): Hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32. max_scale (int, optional): Max_scale used by DynamicGradScaler. Defaults to 2**32.
clipping_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0. max_norm (float, optional): The norm value used to clip gradient. Defaults to 0.0.
norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0) norm_type (float, optional): The type of norm used for gradient clipping. Currently, only L2-norm (norm_type=2.0)
is supported in ZeroOptimizer. Defaults to 2.0. is supported in GeminiOptimizer. Defaults to 2.0.
verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False. verbose (bool, optional): Whether to print verbose information, including grad overflow info. Defaults to False.
""" """
def __init__(self, def __init__(self,
optim: Optimizer, optim: Optimizer,
module: ZeroDDP, module: GeminiDDP,
gpu_margin_mem_ratio: float = 0.0, gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32, initial_scale: float = 2**32,
min_scale: float = 1, min_scale: float = 1,
...@@ -87,12 +88,12 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -87,12 +88,12 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval: int = 1000, growth_interval: int = 1000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: float = 2**32, max_scale: float = 2**32,
clipping_norm: float = 0.0, max_norm: float = 0.0,
norm_type: float = 2.0, norm_type: float = 2.0,
verbose: bool = False, verbose: bool = False,
**defaults: Any): **defaults: Any):
super().__init__(optim) super().__init__(optim)
assert isinstance(module, ZeroDDP) assert isinstance(module, GeminiDDP)
assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \ assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
f"{_AVAIL_OPTIM_LIST}" f"{_AVAIL_OPTIM_LIST}"
self.module = module self.module = module
...@@ -101,8 +102,8 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -101,8 +102,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set() self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0 self.clipping_flag = max_norm > 0.0
self.max_norm = clipping_norm self.max_norm = max_norm
self.verbose = verbose self.verbose = verbose
self.param_groups_backup = list() self.param_groups_backup = list()
...@@ -111,7 +112,7 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -111,7 +112,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.id_to_fake_params: Dict[int, Parameter] = dict() self.id_to_fake_params: Dict[int, Parameter] = dict()
if self.clipping_flag: if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" assert norm_type == 2.0, "GeminiOptimizer only supports L2 norm now"
ddp_param_list = [] ddp_param_list = []
for name, param in module.named_parameters(): for name, param in module.named_parameters():
...@@ -735,8 +736,19 @@ class ZeroOptimizer(ColossalaiOptimizer): ...@@ -735,8 +736,19 @@ class ZeroOptimizer(ColossalaiOptimizer):
yield current_block, current_block_size yield current_block, current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
class GeminiAdamOptimizer(ZeroOptimizer): def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> torch.Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')
class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults) optimizer = HybridAdam(model.parameters(), **defaults)
......
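With `ZeroOptimizer` renamed to `GeminiOptimizer` and `clipping_norm` renamed to `max_norm`, a minimal hedged sketch of the updated call follows; the toy module and hyper-parameter values are placeholders and an already-launched distributed context is assumed:

```python
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer

# hedged sketch: note the gradient-clipping argument is now `max_norm`
# (formerly `clipping_norm`)
model = GeminiDDP(torch.nn.Linear(1024, 1024), torch.cuda.current_device(), placement_policy='auto')
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = GeminiOptimizer(optim, model, initial_scale=2**16, max_norm=1.0)
```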
...@@ -9,7 +9,7 @@ class MemStats(object): ...@@ -9,7 +9,7 @@ class MemStats(object):
def __init__(self) -> None: def __init__(self) -> None:
""" """
Store the non model data statistics used for Gemini and ZeroOptimizer. Store the non model data statistics used for Gemini and GeminiOptimizer.
""" """
# (preop_step, List[param]) # (preop_step, List[param])
self._step_param_dict = dict() self._step_param_dict = dict()
......
import functools import functools
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from time import time from time import time
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
...@@ -7,6 +8,7 @@ import torch ...@@ -7,6 +8,7 @@ import torch
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector from .memory_tracer import ChunkMemStatsCollector
...@@ -17,7 +19,8 @@ class PlacementPolicy(ABC): ...@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
def __init__(self, def __init__(self,
chunk_manager: ChunkManager, chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
**kwargs) -> None:
self.chunk_manager = chunk_manager self.chunk_manager = chunk_manager
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
...@@ -25,57 +28,87 @@ class PlacementPolicy(ABC): ...@@ -25,57 +28,87 @@ class PlacementPolicy(ABC):
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
raise NotImplementedError raise NotImplementedError
@staticmethod @abstractmethod
def get_default_device() -> torch.device: def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
return torch.device('cpu') torch.device]) -> None:
raise NotImplementedError
class CPUPlacementPolicy(PlacementPolicy): class StaticPlacementPolicy(PlacementPolicy):
def __init__(self, def __init__(self,
chunk_manager: ChunkManager, chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
shard_param_frac: float = 1.0,
offload_optim_frac: float = 0.0,
offload_param_frac: float = 0.0,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
offload_param_frac = 0.0
self.shard_param_frac = shard_param_frac
self.offload_optim_frac = offload_optim_frac
self.offload_param_frac = offload_param_frac
# these should be initialized in setup_grads_device
self.keep_gathered_chunk_mem = 0.0
self.keep_cuda_chunk_mem = 0.0
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
volume = 0 can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
start = time() can_offload_chunk_mem = can_shard_chunk_mem
for chunk in can_evict_chunks: for chunk in can_evict_chunks:
if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
break
self.chunk_manager.release_chunk(chunk) self.chunk_manager.release_chunk(chunk)
# real saved mem is chunk_mem - shard_mem, for simplicity we use chunk_mem
can_shard_chunk_mem -= chunk.chunk_mem
for chunk in can_evict_chunks:
if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
break
self.chunk_manager.move_chunk(chunk, torch.device('cpu')) self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
volume += chunk.chunk_mem # real saved mem is shard_mem, for simplicity we use chunk_mem
return volume, time() - start can_offload_chunk_mem -= chunk.chunk_mem
return 0, 0.0
class CUDAPlacementPolicy(PlacementPolicy): def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
def __init__(self, total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' offloaded_optim_chunk_mem = 0
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) chunks = set(self.chunk_manager.get_chunk(p) for p in params)
for chunk in chunks:
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: params = chunk.get_tensors()
return 0, 0 # init offload optim settings
# keep gathered chunks are in CUDA
@staticmethod if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
def get_default_device() -> torch.device: device = get_current_device()
return get_current_device() else:
device = torch.device('cpu')
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
offloaded_optim_chunk_mem += chunk.chunk_mem
for p in params:
grads_device_map[p] = device
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
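A hedged numeric illustration of the thresholds computed by `setup_grads_device` above; the total chunk memory and the fractions are arbitrary assumptions:

```python
# assume the registered chunks total 8 GiB and the policy was configured with:
total_chunk_mem = 8 * 1024**3
shard_param_frac, offload_param_frac, offload_optim_frac = 0.5, 0.0, 0.5

keep_gathered_chunk_mem = total_chunk_mem * (1 - shard_param_frac)   # 4 GiB of chunks stay gathered
keep_cuda_chunk_mem = total_chunk_mem * (1 - offload_param_frac)     # all chunk payloads stay on CUDA
offload_optim_chunk_mem = total_chunk_mem * offload_optim_frac       # grads/optim states for ~4 GiB of chunks go to CPU
```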
class AutoPlacementPolicy(PlacementPolicy): class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True need_mem_stats: bool = True
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
_warmup_non_model_data_ratio: float = 0.8
_steady_cuda_cap_ratio: float = 0.9
def __init__(self, def __init__(self,
chunk_manager: ChunkManager, chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
warmup_non_model_data_ratio: float = 0.8,
steady_cuda_cap_ratio: float = 0.9,
**kwargs) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector) super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
def evict_tensors(self, def evict_tensors(self,
can_evict_chunks: List[Chunk], can_evict_chunks: List[Chunk],
...@@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy): ...@@ -105,11 +138,11 @@ class AutoPlacementPolicy(PlacementPolicy):
used_cuda_model_data = self.chunk_manager.total_mem['cuda'] used_cuda_model_data = self.chunk_manager.total_mem['cuda']
if warmup: if warmup:
# We designate a part of CUDA memory for model data in warmup iterations. # We designate a part of CUDA memory for model data in warmup iterations.
max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
else: else:
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda') max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio cuda_capacity *= self._steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
freed_cuda_model_data = 0 freed_cuda_model_data = 0
...@@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy): ...@@ -145,89 +178,22 @@ class AutoPlacementPolicy(PlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx] return [t for (t, idx) in next_compute_idx]
@staticmethod def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
def set_warmup_non_model_data_ratio(ratio: float) -> None: torch.device]) -> None:
ratio = float(ratio) for p in params:
assert 0.0 < ratio < 1.0 chunk = self.chunk_manager.get_chunk(p)
AutoPlacementPolicy._warmup_non_model_data_ratio = ratio # init offload optim settings
# keep gathered chunks are in CUDA
@staticmethod if chunk.keep_gathered:
def set_steady_cuda_cap_ratio(ratio: float) -> None: grads_device_map[p] = get_current_device()
ratio = float(ratio) else:
assert 0.0 < ratio < 1.0 grads_device_map[p] = torch.device('cpu')
AutoPlacementPolicy._steady_cuda_cap_ratio = ratio
class ConstPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = False
_accessed_memory_boundary = 512 * 1024**2
def __init__(self,
chunk_manager: ChunkManager,
mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
def evict_tensors(self,
can_evict_chunks: List[Chunk],
cuda_demand: int = 0,
warmup: bool = True,
compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
compute_idx: int = 0,
**kwargs) -> Tuple[int, float]:
"""
See the docstrings in the class `AutoPlacementPolicy`.
"""
start = time()
used_accessed_memory = self.chunk_manager.accessed_mem
avail_accessed_memory = ConstPlacementPolicy._accessed_memory_boundary - used_accessed_memory
freed_accessed_memory = 0
if avail_accessed_memory < cuda_demand:
to_free_memory = cuda_demand - avail_accessed_memory
to_free_chunks = can_evict_chunks
if not warmup:
# sort all chunks
to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, tuple(compute_list))
for chunk in to_free_chunks:
if freed_accessed_memory >= to_free_memory:
break
self.chunk_manager.release_chunk(chunk)
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
freed_accessed_memory += chunk.chunk_mem
if freed_accessed_memory < to_free_memory:
raise RuntimeError(f"Adjust layout failed! No enough CUDA memory! "
f"Need {to_free_memory}, freed {freed_accessed_memory}")
return freed_accessed_memory, time() - start
@staticmethod
@functools.lru_cache(maxsize=None)
def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
for i in range(len(compute_list) - 1, compute_idx, -1):
for chunk in compute_list[i]:
if chunk in next_compute_idx:
next_compute_idx[chunk] = i
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
return [t for (t, idx) in next_compute_idx]
@staticmethod
def set_const_memory_boundary(cuda_memory_mb: int) -> None:
boundary = int(cuda_memory_mb * 1024**2)
assert boundary > 0
ConstPlacementPolicy._accessed_memory_boundary = boundary
class PlacementPolicyFactory: class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = { policies: Dict[str, Type[PlacementPolicy]] = {
'cpu': CPUPlacementPolicy,
'cuda': CUDAPlacementPolicy,
'auto': AutoPlacementPolicy, 'auto': AutoPlacementPolicy,
'const': ConstPlacementPolicy 'static': StaticPlacementPolicy,
} }
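The removed `cpu` and `cuda` policies can likely be approximated with the new `static` policy. The mapping below is an editorial inference from the eviction behaviour of the removed classes, not something stated in this commit:

```python
# hedged mapping (inferred, not part of the commit):
# the old "cuda" policy kept every chunk gathered on CUDA
cuda_like = dict(placement_policy='static', shard_param_frac=0.0,
                 offload_optim_frac=0.0, offload_param_frac=0.0)
# the old "cpu" policy sharded every chunk and kept parameters and optimizer states on CPU
cpu_like = dict(placement_policy='static', shard_param_frac=1.0,
                offload_optim_frac=1.0, offload_param_frac=1.0)
```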
@staticmethod @staticmethod
...@@ -239,8 +205,3 @@ class PlacementPolicyFactory: ...@@ -239,8 +205,3 @@ class PlacementPolicyFactory:
@staticmethod @staticmethod
def get_policy_names(): def get_policy_names():
return tuple(PlacementPolicyFactory.policies.keys()) return tuple(PlacementPolicyFactory.policies.keys())
@staticmethod
def get_default_device(policy_name: str) -> torch.device:
policy_cls = PlacementPolicyFactory.create(policy_name)
return policy_cls.get_default_device()
...@@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model, ...@@ -64,13 +64,13 @@ def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"), device=torch.device("cpu"),
dtype=torch.float32, dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module: only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given ZeroDDP module. """Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original ZeroDDP model is not modified. You should notice that the original GeminiDDP model is not modified.
Thus, you can use the original model in further training. Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors. But you should not use the returned torch model to train, this can cause unexpected errors.
Args: Args:
zero_ddp_model (ZeroDDP): a zero ddp model zero_ddp_model (GeminiDDP): a zero ddp model
device (torch.device): the device of the final torch model device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the converted torch model only_rank_0 (bool): if True, only rank0 has the converted torch model
...@@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model, ...@@ -78,8 +78,8 @@ def get_static_torch_model(zero_ddp_model,
Returns: Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
""" """
from colossalai.zero.gemini.gemini_ddp import ZeroDDP from colossalai.zero.gemini.gemini_ddp import GeminiDDP
assert isinstance(zero_ddp_model, ZeroDDP) assert isinstance(zero_ddp_model, GeminiDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = zero_ddp_model.module colo_model = zero_ddp_model.module
......
...@@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module, ...@@ -109,6 +109,6 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['clip_grad_norm'] = max_norm config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose) return LowLevelZeroOptimizer(optimizer, **config_dict, verbose=verbose)
else: else:
from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer from colossalai.zero.gemini.gemini_optimizer import GeminiOptimizer
config_dict['clipping_norm'] = max_norm config_dict['clipping_norm'] = max_norm
return ZeroOptimizer(optimizer, model, **config_dict, verbose=verbose) return GeminiOptimizer(optimizer, model, **config_dict, verbose=verbose)
...@@ -54,32 +54,38 @@ We also provide a lightweight chunk search mechanism to help users automatically ...@@ -54,32 +54,38 @@ We also provide a lightweight chunk search mechanism to help users automatically
We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management.
Also make sure that your model is initialized under the context of ColoInitContext. Gemini allows LazyInitContext, which can save memory when initializing large models on multiple GPUs.
If your model has `N` billion parameters and your GPU memory is `M` GB, we recommend you use LazyInitContext when `4N >= M`. Otherwise, LazyInitContext is optional.
<!--- doc-test-ignore-start -->
```python ```python
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): with LazyInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True) model = gpt2_medium(checkpoint=True)
``` ```
<!--- doc-test-ignore-end -->
We provide the user-friendly `Booster` API and recommend using it. If you still want to use the low-level API, you can read the rest of this section.
Define the model parameters as follows: Wrap the model with `GeminiDDP`.
<!--- doc-test-ignore-start -->
```python ```python
chunk_manager = init_chunk_manager(model=module, model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
``` ```
<!--- doc-test-ignore-end -->
`hidden_dim` is the hidden dimension of the DNN. Users can provide this argument to speed up the chunk-size search; if it is unknown before training, that is fine and a default value of 1024 will be used. `min_chunk_size_m` is a floating-point number giving the minimum chunk size divided by 2^20 (e.g., if min_chunk_size_m=2.5, the minimum chunk size is 2.5*(2^20)). If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
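A hedged illustration of the unit used by `min_chunk_size_m` (the value is arbitrary):

```python
# hedged arithmetic for the example in the paragraph above
min_chunk_size_m = 2.5
min_chunk_size = min_chunk_size_m * 2**20   # 2,621,440
```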
Initialization of the optimizer.
<!--- doc-test-ignore-start -->
```python ```python
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
``` ```
<!--- doc-test-ignore-end -->
Training
<!--- doc-test-ignore-start -->
```python ```python
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(input_ids, attn_mask) outputs = model(input_ids, attn_mask)
...@@ -87,6 +93,7 @@ loss = criterion(outputs, input_ids) ...@@ -87,6 +93,7 @@ loss = criterion(outputs, input_ids)
optimizer.backward(loss) optimizer.backward(loss)
optimizer.step() optimizer.step()
``` ```
<!--- doc-test-ignore-end -->
> ⚠️ Note: Please do not use `loss.backward()`; the standard way is `optimizer.backward(loss)`.
### Train GPT
...@@ -142,46 +149,6 @@ class GPTLMLoss(nn.Module): ...@@ -142,46 +149,6 @@ class GPTLMLoss(nn.Module):
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
``` ```
Define tensor parallel and parameter sharding strategies for tensor parallelism:
```python
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg)
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg)
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg)
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg)
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
```
Write a function to get random inputs:
...@@ -198,7 +165,7 @@ Finally, we define a model which uses Gemini + ZeRO DDP and define our training ...@@ -198,7 +165,7 @@ Finally, we define a model which uses Gemini + ZeRO DDP and define our training
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.zero import ColoInitContext from colossalai.lazy import LazyInitContext
from colossalai.booster.plugin import GeminiPlugin from colossalai.booster.plugin import GeminiPlugin
def main(): def main():
...@@ -214,17 +181,13 @@ def main(): ...@@ -214,17 +181,13 @@ def main():
optimizer = HybridAdam(model.parameters(), lr=0.001) optimizer = HybridAdam(model.parameters(), lr=0.001)
torch.manual_seed(123) torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model # build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): with ColoInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True) model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5) # Gemini + ZeRO DP
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
......
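Putting the pieces of this tutorial together, here is a hedged end-to-end sketch; `gpt2_medium`, `GPTLMLoss` and `get_data` are the helpers defined in this tutorial, and the batch shape and hyper-parameters are placeholders:

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

# build the GPT-2 model lazily to save memory at initialization
with LazyInitContext(default_device=torch.device('cuda')):
    model = gpt2_medium(checkpoint=True)      # helper defined earlier in this tutorial

criterion = GPTLMLoss()                       # helper defined earlier in this tutorial
optimizer = HybridAdam(model.parameters(), lr=0.001)

# Gemini + ZeRO DP via the Booster plugin
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# one training step with random data (get_data is the helper defined earlier)
input_ids, attn_mask = get_data(batch_size=2, seq_len=1024, vocab_size=50257)
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)
optimizer.step()
```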
...@@ -53,32 +53,37 @@ ...@@ -53,32 +53,37 @@
We will use `GeminiDDP` to enable ZeRO with chunk-based memory management. This is our new torch.Module wrapper, which uses ZeRO-DP and Gemini: ZeRO handles parallelism and Gemini handles memory management.
Also make sure that your model is initialized under the context of `ColoInitContext`. Gemini supports lazy initialization, which can save GPU memory when initializing large models on multiple GPUs.
If your model has `N` billion parameters and your GPU memory is `M` GB, we recommend using LazyInitContext when `4N >= M`. Otherwise, LazyInitContext is optional.
<!--- doc-test-ignore-start -->
```python ```python
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): with LazyInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True) model = gpt2_medium(checkpoint=True)
``` ```
<!--- doc-test-ignore-end -->
We provide the user-friendly `Booster` API and recommend using it. If you still want to use the low-level API, you can read the rest of this section.
Define the model parameters as follows: Wrap the model with `GeminiDDP`.
<!--- doc-test-ignore-start -->
```python ```python
chunk_manager = init_chunk_manager(model=module, model = GeminiDDP(model, hidden_dim=hidden_dim, min_chunk_size_m=min_chunk_size_m)
init_device=device,
hidden_dim=hidden_dim,
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)
``` ```
<!--- doc-test-ignore-end -->
`hidden_dim` is the hidden dimension of the DNN. Users can provide this argument to speed up the search; if it is unknown before training, that is fine and a default value of 1024 will be used. `min_chunk_size_m` is the minimum chunk size in units of mega (2^20). If the aggregate size of parameters is still smaller than the minimum chunk size, all parameters will be compacted into one small chunk.
Initialize the optimizer.
<!--- doc-test-ignore-start -->
```python ```python
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
``` ```
<!--- doc-test-ignore-end -->
<!--- doc-test-ignore-start -->
Training
```python ```python
optimizer.zero_grad() optimizer.zero_grad()
...@@ -87,6 +92,7 @@ loss = criterion(outputs, input_ids) ...@@ -87,6 +92,7 @@ loss = criterion(outputs, input_ids)
optimizer.backward(loss) optimizer.backward(loss)
optimizer.step() optimizer.step()
``` ```
<!--- doc-test-ignore-end -->
> ⚠️ Note: Please do not use `loss.backward()`; the standard way is `optimizer.backward(loss)`.
### Train GPT
...@@ -143,47 +149,6 @@ class GPTLMLoss(nn.Module): ...@@ -143,47 +149,6 @@ class GPTLMLoss(nn.Module):
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
``` ```
Define tensor parallel and parameter sharding strategies:
```python
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg)
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg)
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg)
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg)
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
```
Write a function to get random inputs:
```python ```python
...@@ -200,7 +165,7 @@ def get_data(batch_size, seq_len, vocab_size): ...@@ -200,7 +165,7 @@ def get_data(batch_size, seq_len, vocab_size):
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.zero import ColoInitContext from colossalai.lazy import LazyInitContext
from colossalai.booster.plugin import GeminiPlugin from colossalai.booster.plugin import GeminiPlugin
def main(): def main():
...@@ -216,17 +181,13 @@ def main(): ...@@ -216,17 +181,13 @@ def main():
optimizer = HybridAdam(model.parameters(), lr=0.001) optimizer = HybridAdam(model.parameters(), lr=0.001)
torch.manual_seed(123) torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model # build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg): with ColoInitContext(default_device=torch.device('cuda')):
model = gpt2_medium(checkpoint=True) model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5) # Gemini + ZeRO DP
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
......
...@@ -22,7 +22,7 @@ from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wra ...@@ -22,7 +22,7 @@ from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wra
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer from colossalai.zero import GeminiOptimizer
def main(): def main():
...@@ -46,7 +46,7 @@ def main(): ...@@ -46,7 +46,7 @@ def main():
args.local_rank = -1 args.local_rank = -1
args.log_interval = 1 args.log_interval = 1
else: else:
colossalai.launch_from_torch(config={}) #args.colossal_config colossalai.launch_from_torch(config={}) # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"]) args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info( logger.info(
f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
...@@ -123,7 +123,8 @@ def main(): ...@@ -123,7 +123,8 @@ def main():
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
# 144003367 is is the length of the entire dataset # 144003367 is is the length of the entire dataset
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) # len(dataloader)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
total_steps = steps_per_epoch * args.epoch total_steps = steps_per_epoch * args.epoch
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
......
...@@ -20,6 +20,5 @@ for plugin in "gemini"; do ...@@ -20,6 +20,5 @@ for plugin in "gemini"; do
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--lr_warmup_steps=0 \ --lr_warmup_steps=0 \
--test_run=True \ --test_run=True \
--num_class_images=200 \ --num_class_images=200
--placement="auto" # "cuda"
done done