Unverified Commit c622bb36 authored by Frank Lee, committed by GitHub

Merge pull request #3915 from FrankLeeeee/update/develop

[sync] update develop with main
parents 34966378 9c88b6cb
......@@ -123,7 +123,7 @@ class WorkerBase(ABC):
self.device = device
self._initialize_outstanding_range()
# variable and const for context managment
# variable and const for context management
self.outstanding = 0
self.forward_times = 0
self.backward_times = 0
......@@ -226,7 +226,7 @@ class WorkerBase(ABC):
self.pp_rank_to_worker_rref = pp_rank_to_worker_rref
# for some schedule need the other worker's info to initialise partition (like Chimera)
# construction of partition is executed after the registion of pp_rank_to_worker_rref
# construction of partition is executed after the registration of pp_rank_to_worker_rref
self._initialize_partition()
# res_use works for lifecycle counter,
......@@ -418,7 +418,7 @@ class WorkerBase(ABC):
# On current PP middleware design for DAG, get_output_by_key used by _subscribe_producer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of work_item. So, keeping the _subscribe_producer in the same
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
# lock of work_item queue operation guarantees the consistency of lifecycle counter.
work_item_from_producer = self._subscribe_producer(microbatch_id, forward_only)
self.work_list[key] = work_item_from_producer
self.work_list_condition_lock.notify_all()
......@@ -460,7 +460,7 @@ class WorkerBase(ABC):
# On current PP middleware design for DAG, get_output_by_key used by subscribe_consumer
# can only be executed once for every producer-consumer stage pair, which is necessary
# to count the lifecycle of work_item. So, keeping the subscribe_consumer in the same
# lock of work_item queue operation gurantees the consistency of lifecycle counter.
# lock of work_item queue operation guarantees the consistency of lifecycle counter.
work_item_from_consumer = self._subscribe_consumer(microbatch_id)
self.work_list[key] = work_item_from_consumer
self.work_list_condition_lock.notify_all()
......@@ -508,7 +508,7 @@ class WorkerBase(ABC):
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
# should be aranged in order, the order of the input of current forward
# should be arranged in order, the order of the input of current forward
self.producer_stage_ids = self.get_producer_stage_ids()
self.consumer_stage_ids = self.get_consumer_stage_ids()
......
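The comments in the hunks above explain why `_subscribe_producer`/`_subscribe_consumer` must run inside the same lock as the work-list update, so the work item's lifecycle counter stays consistent. A minimal sketch of that condition-variable pattern using only the standard library; the class and method names here are hypothetical stand-ins, not the worker's real API:

```python
import threading
from typing import Any, Dict

class WorkQueueSketch:
    """Toy illustration of guarding a work list with a single condition variable."""

    def __init__(self) -> None:
        self.work_list: Dict[int, Any] = {}
        self.work_list_condition_lock = threading.Condition()

    def put_work_item(self, key: int, item: Any) -> None:
        # Registering the item and notifying consumers happen under one lock,
        # so a consumer can never observe a half-registered work item.
        with self.work_list_condition_lock:
            self.work_list[key] = item
            self.work_list_condition_lock.notify_all()

    def get_work_item(self, key: int) -> Any:
        with self.work_list_condition_lock:
            self.work_list_condition_lock.wait_for(lambda: key in self.work_list)
            return self.work_list.pop(key)
```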
......@@ -123,7 +123,7 @@ class ChimeraWorker(WorkerBase):
assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed"
assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed"
# should be aranged in order, the order of the input of current forward
# should be arranged in order, the order of the input of current forward
self.producer_stage_ids = []
self.consumer_stage_ids = []
......@@ -174,7 +174,7 @@ class ChimeraWorker(WorkerBase):
else:
# if it is down pipeline, create partition by origin method
co_up_pp_worker_rref = self.pp_rank_to_worker_rref[pp_rank - stage_num]
# get the coresponding model state dict and wait for its init
# get the corresponding model state dict and wait for its init
state_dict = co_up_pp_worker_rref.rpc_sync().get_partition_state_dict()
super()._initialize_partition()
self.module_partition.load_state_dict(state_dict)
......@@ -228,7 +228,7 @@ class ChimeraWorker(WorkerBase):
stage_num = self.actual_stage_num
co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
# if currrent pp_rank is not the first to do step
# if current pp_rank is not the first to do step
# wait its previous pp_rank finish step
grads = self.get_parameter_gradients()
......
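The `co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)` mapping above pairs each rank of the up pipeline with its counterpart in the down pipeline. A quick numeric check of that formula, with `stage_num = 4` chosen purely for illustration:

```python
stage_num = 4    # assumed number of stages per pipeline, for illustration only

for pp_rank in range(2 * stage_num):
    co_pp_rank = (pp_rank + stage_num) % (2 * stage_num)
    print(pp_rank, "<->", co_pp_rank)
# ranks 0..3 (up pipeline) map to 4..7 (down pipeline) and vice versa
```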
......@@ -113,7 +113,7 @@ def _binary_search(weights, num):
def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
"Layer length should be divided by the number of chunks, otherwise parameter method is recommended"
logger = get_dist_logger()
parts = [[] for _ in range(pipeline_parallel_size)]
......
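`partition_uniform` requires the layer count to be divisible by the number of chunks. A simplified sketch of uniform partitioning under that assumption (close in spirit to, but not identical with, the library's implementation):

```python
def partition_uniform_sketch(num_items: int, pipeline_parallel_size: int, num_chunks: int):
    """Return a list of (start, end) layer ranges per pipeline rank."""
    assert num_items % num_chunks == 0, \
        "Layer length should be divisible by the number of chunks"
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        left = partition_items % pipeline_parallel_size
        start = base
        for rank in range(pipeline_parallel_size):
            size = chunk_size + (1 if rank < left else 0)
            parts[rank].append((start, start + size))
            start += size
    return parts

print(partition_uniform_sketch(8, 2, 2))
# [[(0, 2), (4, 6)], [(2, 4), (6, 8)]]
```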
......@@ -28,7 +28,7 @@ class CommSpec:
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
......
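`gather_dim` and `shard_dim` describe which tensor dimension a collective gathers or splits along. A single-process illustration of what that means, using plain `torch.chunk`/`torch.cat` rather than the actual `CommSpec` machinery (the world size is an assumed value):

```python
import torch

world_size = 2                      # assumed logical process-axis size, for illustration
full = torch.arange(8.).reshape(4, 2)

# "shard" along shard_dim 0: each rank keeps one slice
shards = list(torch.chunk(full, world_size, dim=0))

# "all-gather" along gather_dim 0: concatenating the slices restores the tensor
gathered = torch.cat(shards, dim=0)
assert torch.equal(gathered, full)

# "all-to-all": gather along dim 0 while re-sharding along dim 1
resharded = [torch.cat([s.chunk(world_size, dim=1)[r] for s in shards], dim=0)
             for r in range(world_size)]
assert torch.equal(torch.cat(resharded, dim=1), full)
```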
......@@ -41,7 +41,7 @@ class DimSpec:
def _convert_str_to_shard_list(self, str_spec):
'''
Conver str_spec into shard_list.
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
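`_convert_str_to_shard_list` turns a dim-spec string into the list of logical mesh axes that dimension is sharded over. A hedged sketch of the convention ('R' for replicated, 'S' followed by axis digits for sharded), assuming a well-formed spec string; the function name here is a stand-in:

```python
from typing import List

def convert_str_to_shard_list(str_spec: str) -> List[int]:
    """'R' -> [] (replicated); 'S0' -> [0]; 'S01' -> [0, 1] (sharded on mesh axes 0 and 1)."""
    if str_spec == 'R':
        return []
    assert str_spec.startswith('S'), f"unknown dim spec: {str_spec}"
    return [int(axis) for axis in str_spec[1:]]

assert convert_str_to_shard_list('R') == []
assert convert_str_to_shard_list('S0') == [0]
assert convert_str_to_shard_list('S01') == [0, 1]
```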
......@@ -58,7 +58,7 @@ class DimSpec:
def build_difference_2d_dict(self):
'''
Build a difference maping for 2D device mesh case. It will be used to
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
......
......@@ -164,7 +164,7 @@ def _get_grad_args(*args):
for obj in args:
if _is_grad_tensor(obj):
return args, None
# otherwise, the first arguement should be a tuple of grad tensors
# otherwise, the first argument should be a tuple of grad tensors
# if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered
arg_zero = args[0]
if not isinstance(arg_zero, tuple):
......
......@@ -130,7 +130,7 @@ class ProcessGroup:
@property
def has_cpu_groups(self) -> bool:
"""has_cpu_groups
If cpu groups have been initailized.
If cpu groups have been initialized.
Returns:
bool: cpu process groups have been initialized or not.
......
......@@ -252,7 +252,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict):
'''
Get all valid sharding specs from source_spec with single shard operation, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
For the sharding operation, we just care about legal sharding dimensions.
Argument:
......@@ -386,7 +386,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with one step transform, and
accumulate commucation cost on origin cost which will finally be used in auto sharding solver.
accumulate communication cost on origin cost which will finally be used in auto sharding solver.
Note:
all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
......@@ -577,7 +577,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
Step3:
Repeat above steps until the source spec transform to target spec.
During finding the transform path, commucation cost will be accumulated, and it
During finding the transform path, communication cost will be accumulated, and it
will be finally used in auto parallel solver.
Additionally, to avoid repeating the path search in runtime, we cached all solved path
......
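The note above states that all-gather removes a sharding dimension, all-to-all keeps the count unchanged, and shard adds one, so the three one-step transforms can never produce the same spec. A toy illustration on sharding-sequence strings (illustrative only; the real manager works on `ShardingSpec` objects and accumulates communication cost along the path):

```python
source = ['S0', 'R']           # dim 0 sharded on mesh axis 0, dim 1 replicated

after_all_gather = ['R', 'R']  # all-gather dim 0: one fewer sharded dimension
after_all_to_all = ['R', 'S0'] # all-to-all dims 0 and 1: same count, sharding moved
after_shard = ['S0', 'S1']     # shard dim 1 on mesh axis 1: one more sharded dimension

def num_sharded(spec):
    return sum(s != 'R' for s in spec)

assert num_sharded(after_all_gather) == num_sharded(source) - 1
assert num_sharded(after_all_to_all) == num_sharded(source)
assert num_sharded(after_shard) == num_sharded(source) + 1
```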
......@@ -45,7 +45,7 @@ class _DimSpec:
def _convert_str_to_shard_list(self, str_spec):
'''
Conver str_spec into shard_list.
Convert str_spec into shard_list.
Argument:
str_spec(str): dim spec in str type.
......@@ -62,7 +62,7 @@ class _DimSpec:
def build_difference_2d_dict(self):
'''
Build a difference maping for 2D device mesh case. It will be used to
Build a difference mapping for 2D device mesh case. It will be used to
compute the difference between DimSpec pairs.
'''
......@@ -166,7 +166,7 @@ class ShardingSpec:
device_mesh(DeviceMesh): A logical view of a physical mesh.
entire_shape(torch.Size): The entire shape of tensor before sharded.
dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
and the value of the key describe which logical axis will be sharded in that dimension.
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
......
......@@ -77,7 +77,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element decribes which logical axis will be sharded in that dimension.
and the second element describes which logical axis will be sharded in that dimension.
'''
_, shard_list = target_pair
shard_list_list = []
......
#!/usr/bin/env python
# coding: utf-8
import inspect
import types
from typing import Callable, List
import torch
import torch.nn as nn
from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import substitute_init_recursively
class LazyInitContext():
"""
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
initialization functions for lazy initialization
Note:
This API is only experimental and subject to future changes.
Usage:
with LazyInitContext() as ctx:
model = nn.Linear(10, 10)
model.weight.zero_()
# make sure the weight is a meta tensor
assert model.weight.is_meta
# initialize weights
ctx.lazy_init_parameters(model)
# make sure the weight is not a meta tensor
# and initialized correctly
assert not model.weight.is_meta and torch.all(model.weight == 0)
Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
"""
tensor_set_value_func = ['zero_', 'fill_']
def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
# TODO: hijack the torch constructor functions as well
self._to_meta = to_meta
self._intercepted_nn_init_func_cache = {}
self._nn_init_methods = self._get_nn_init_methods()
self._torch_mod_cls = torch.nn.modules.module.Module
if extra_torch_tensor_func:
# use tuple to remove duplicates
self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
else:
self._torch_tensor_funcs = self.tensor_set_value_func
@property
def to_meta(self):
return self._to_meta
def _cache_init_func(self, func):
"""
This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
so that the function call is cached instead of being executed.
"""
def wrapped_init_func(tensor, *args, **kwargs):
if tensor not in self._intercepted_nn_init_func_cache:
self._intercepted_nn_init_func_cache[tensor] = []
self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))
return wrapped_init_func
def _get_nn_init_methods(self):
"""
This method looks for all available functions in the ``torch.nn.init``
module.
"""
nn_init_method_names = dir(torch.nn.init)
nn_init_methods = []
# look for all methods in ``torch.nn.init`` module
for name in nn_init_method_names:
nn_init_methods.append((name, getattr(torch.nn.init, name)))
def _is_init_method(item):
name, func = item
if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
return False
else:
return True
# remove methods which are not init functions
nn_init_methods = list(filter(_is_init_method, nn_init_methods))
return nn_init_methods
def _wrap_module_init(self, func):
"""
This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
the argument device with value 'meta' so that all modules are created as meta tensors.
"""
has_device = 'device' in inspect.signature(func).parameters
def layer_lazy_init(module, *args, **kwargs):
# if this module contains device argument
# we set it to meta to initialize as meta backend
if has_device:
kwargs['device'] = 'meta'
func(module, *args, **kwargs)
# if device is not found, we intialize it and convert to meta
if not has_device:
module.to('meta')
return layer_lazy_init
def _get_tmp_origin_func_ref(self, name):
"""
Generate a function name for consistency during caching and retrieving.
"""
return f'_orig_{name}'
def _patch_nn_init_funcs(self):
# patch nn.init functions
for name, func in self._nn_init_methods:
setattr(torch.nn.init, name, self._cache_init_func(func))
def _unpatch_nn_init_funcs(self):
# unpatch nn.init functions
for name, func in self._nn_init_methods:
setattr(torch.nn.init, name, func)
def _patch_submodule_init(self):
# patch classes __init__ methods
def _activate_wrap_init(cls):
cls.__orig_init__ = cls.__init__
cls.__init__ = self._wrap_module_init(cls.__init__)
substitute_init_recursively(self._torch_mod_cls, _activate_wrap_init, set())
def _unpatch_submodule_init(self):
def _recover_orig_init(cls):
cls.__init__ = cls.__orig_init__
substitute_init_recursively(self._torch_mod_cls, _recover_orig_init, set())
def _patch_torch_tensor_funcs(self):
# patch tensor value-setting functions
for func_name in self._torch_tensor_funcs:
origin_func_name = self._get_tmp_origin_func_ref(func_name)
origin_func = getattr(torch.Tensor, func_name)
setattr(torch.Tensor, origin_func_name, origin_func)
setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))
def _unpatch_torch_tensor_funcs(self):
for func_name in self._torch_tensor_funcs:
origin_func_name = self._get_tmp_origin_func_ref(func_name)
origin_func = getattr(torch.Tensor, origin_func_name)
setattr(torch.Tensor, func_name, origin_func)
def __enter__(self):
self._patch_torch_tensor_funcs()
self._patch_nn_init_funcs()
if self._to_meta:
self._patch_submodule_init()
return self
def __exit__(self, *args, **kwargs):
if self._to_meta:
self._unpatch_submodule_init()
self._unpatch_nn_init_funcs()
self._unpatch_torch_tensor_funcs()
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu'):
"""
Initialize the weights of the meta-tensor model.
Args:
model (`torch.nn.Module`): the model instantiated under the context.
device (str): the device on which weights are initialized
"""
def _init_recursively(module: nn.Module):
# recursively initialize the module
for mod in module.children():
_init_recursively(mod)
# initialize and shard tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
_init_and_shard(module, name, param)
for name, buf in module.named_buffers(recurse=False):
_init_and_shard(module, name, buf)
@torch.no_grad()
def _init_and_shard(module, name, tensor):
# check whether the tensor is a buffer or parameter
is_param = isinstance(tensor, nn.parameter.Parameter)
# get sharding spec
dist_spec = getattr(tensor, 'dist_spec', None)
pg = getattr(tensor, 'pg', None)
comp_spec = getattr(tensor, 'comp_spec', None)
# convert the tensor from meta to materialized one
if tensor.is_meta:
materialized_tensor = torch.empty_like(tensor, device=device)
# if this tensor is a meta tensor, it must have an init function
assert tensor in self._intercepted_nn_init_func_cache
else:
materialized_tensor = tensor
# apply init function
if tensor in self._intercepted_nn_init_func_cache:
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
init_func(materialized_tensor, *args, **kwargs)
# convert it to ColoTensor or ColoParameter
if is_param:
tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
else:
tensor = ColoTensor.from_torch_tensor(materialized_tensor)
# override the original tensor
with torch.no_grad():
setattr(module, name, tensor)
# apply sharding
if dist_spec:
tensor.process_group = pg
tensor.set_tensor_spec(dist_spec, comp_spec)
_init_recursively(model)
return model
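The core trick of `LazyInitContext` above is to temporarily replace the `torch.nn.init` functions with wrappers that record the call instead of executing it, then replay the recorded call once the tensor is materialized. A stripped-down, standalone sketch of that patch/record/replay cycle (it ignores meta tensors, module `__init__` patching and sharding, and uses `id()`-keyed caching for simplicity):

```python
import torch

recorded = {}          # tensor id -> (init_func, args, kwargs)

def record_only(func):
    def wrapper(tensor, *args, **kwargs):
        recorded[id(tensor)] = (func, args, kwargs)
        return tensor
    return wrapper

orig_ones_ = torch.nn.init.ones_
torch.nn.init.ones_ = record_only(orig_ones_)   # patch
try:
    w = torch.zeros(2, 2)
    torch.nn.init.ones_(w)                      # call is recorded, not executed
    assert torch.all(w == 0)
finally:
    torch.nn.init.ones_ = orig_ones_            # unpatch

init_func, args, kwargs = recorded[id(w)]       # replay on the "materialized" tensor
init_func(w, *args, **kwargs)
assert torch.all(w == 1)
```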
......@@ -2,13 +2,14 @@ import itertools
from collections import OrderedDict
from contextlib import nullcontext
from functools import partial
from typing import Dict, Iterator, List, Optional, Union, Tuple, Set
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
......@@ -16,7 +17,6 @@ from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.utils.model.experimental import LazyTensor
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
......@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""
def __init__(self,
......@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True) -> None:
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
......@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self._logger = get_dist_logger()
......@@ -96,11 +100,14 @@ class ZeroDDP(ColoDDP):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
super().__init__(module, process_group=ColoProcessGroup())
self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
def _get_non_persistent_buffers_set(self,
module,
memo: Optional[Set[nn.Module]] = None,
prefix: str = '',
remove_duplicate: bool = True):
r"""
Args:
memo: a memo to store the set of modules already added to the result
......@@ -115,16 +122,17 @@ class ZeroDDP(ColoDDP):
if module not in memo:
if remove_duplicate:
memo.add(module)
self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
self_non_persistent_set = set(
map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
for name, sub_module in module._modules.items():
if sub_module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix,
remove_duplicate)
self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
return self_non_persistent_set
def _post_forward(self):
"""This function is only triggered for inference.
"""
......@@ -147,7 +155,7 @@ class ZeroDDP(ColoDDP):
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
self.module.zero_grad(set_to_none=True)
if not grad_flag:
outputs = self._inference_forward(*args, **kwargs)
......@@ -566,14 +574,14 @@ class ZeroDDP(ColoDDP):
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
......@@ -609,7 +617,7 @@ class ZeroDDP(ColoDDP):
buffer.materialize()
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
buffer.data = buffer.to(self.mixed_precision)
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
"""Convert parameter to ColoParameter in-place.
......@@ -732,6 +740,7 @@ class GeminiDDP(ZeroDDP):
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
......@@ -772,5 +781,10 @@ class GeminiDDP(ZeroDDP):
strict_ddp_flag=strict_ddp_mode,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
scatter_after_inference)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision)
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple
import torch
......@@ -9,7 +8,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
......@@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
module: ZeroDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.module = module
def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0
def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
class ZeroOptimizer(ColossalaiOptimizer):
......@@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
......@@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
if module.mixed_precision is torch.float16:
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
elif module.mixed_precision is torch.bfloat16:
self.mix_precision_mixin = BF16MixedPrecisionMixin()
else:
raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
......@@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
for chunk16 in self.chunk16_set:
chunk16.optim_update()
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter)
# all-reduce across global group
dist.all_reduce(self._found_overflow)
return self._found_overflow.item() > 0
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
......@@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
return global_norm
def _get_combined_scale(self):
loss_scale = 1
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
div_scale = self.mix_precision_mixin.get_grad_div_scale()
combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
div_scale = clip * div_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
return -1 if div_scale == 1.0 else div_scale
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = 0
self.mix_precision_mixin.pre_zero_grad()
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
if self.mix_precision_mixin.should_skip_step():
if self.verbose:
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
......@@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
......@@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
......@@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self):
......
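`_get_combined_scale` folds the gradient div scale from the mixin and the clipping factor into one divisor handed to the fused optimizer. A numeric sketch of that arithmetic with made-up values:

```python
loss_scale = 2.0 ** 16   # grad div scale reported by the fp16 mixin (assumed value)
total_norm = 2.0 ** 18   # L2 norm of the *scaled* gradients (assumed value)
max_norm = 1.0           # clipping threshold

clip = ((total_norm / loss_scale) + 1e-6) / max_norm   # ~4: unscaled norm exceeds max_norm
div_scale = clip * loss_scale if clip > 1 else loss_scale

# the optimizer later computes grad / div_scale, which simultaneously removes
# the loss scale and clips the global gradient norm to roughly max_norm
assert abs(total_norm / div_scale - max_norm) < 1e-3
```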
......@@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2
......@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
......@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed: int = 2**10 - 1,
shard_param: bool = False,
default_dtype: Optional[torch.dtype] = None,
bf16: bool = False,
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
super().__init__(default_dtype=default_dtype)
......@@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.param_list = []
self.model_numel_tensor = model_numel_tensor
self.seed = seed
self.bf16 = bf16
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
......@@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
NOTE() The module may be passed to this function multiple times.
"""
self.top_module = module
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
return t.to(half_dtype) if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
......@@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_tensor_to_fp16(buffer.data)
buffer.data = cast_fn(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):
......
......@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
return tensor.float()
return tensor
def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
return tensor.bfloat16()
return tensor
def apply_to_tensors(x: Any, fn: Callable):
if torch.is_tensor(x):
return fn(x)
......
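`cast_tensor_to_bf16` only converts fp32 floating-point tensors and is applied to nested argument structures through helpers like `apply_to_tensors` / `cast_float_arguments`. A small standalone sketch of that behaviour (plain torch; the `_sketch` names are stand-ins, not the library helpers):

```python
import torch

def cast_tensor_to_bf16_sketch(t: torch.Tensor) -> torch.Tensor:
    # only fp32 floats are converted; integer and already-half tensors pass through
    if torch.is_floating_point(t) and t.dtype is torch.float32:
        return t.bfloat16()
    return t

def apply_to_tensors_sketch(x, fn):
    if torch.is_tensor(x):
        return fn(x)
    if isinstance(x, (list, tuple)):
        return type(x)(apply_to_tensors_sketch(v, fn) for v in x)
    if isinstance(x, dict):
        return {k: apply_to_tensors_sketch(v, fn) for k, v in x.items()}
    return x

args = (torch.ones(2), {'mask': torch.ones(2, dtype=torch.long)})
casted = apply_to_tensors_sketch(args, cast_tensor_to_bf16_sketch)
assert casted[0].dtype is torch.bfloat16
assert casted[1]['mask'].dtype is torch.long     # non-float left untouched
```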
......@@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc
from ._utils import (
cast_float_arguments,
cast_tensor_to_bf16,
cast_tensor_to_fp16,
cast_tensor_to_fp32,
chunk_and_pad,
......@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""
def __init__(self,
......@@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__()
self.logger = get_dist_logger()
self.bf16 = bf16
# We force users to use ZeroInitContext
for submodule in module.modules():
......@@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations(*args)
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
outputs = self.module(*args, **kwargs)
self._post_forward_operations()
return outputs
......
......@@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
self.bf16 = sharded_model.bf16
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
......@@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
self._verbose = verbose
self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
# Store fp32 param shards
self._register_master_weight()
......@@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._zero_grad()
def backward(self, loss: Tensor) -> None:
if not self.bf16:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
......@@ -175,23 +179,26 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
if not self.bf16:
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
# unscale grads if scaled
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()
self._maybe_move_fp32_shards()
if not self.bf16:
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
......@@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
state[k] = v.cuda()
def _prepare_grads(self):
if self._grad_prepared:
return
for group in self.optim.param_groups:
for p in group['params']:
if p.colo_attr.saved_grad.is_null():
......@@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.set_data_none()
self._grad_prepared = True
def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
......@@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
......
......@@ -6,7 +6,11 @@ import torch
import torch.distributed as dist
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
......@@ -27,6 +31,31 @@ from ._utils import (
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
num_working_param_groups: int,
grad_store: GradientStore,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
def check_local_overflow(self) -> bool:
for group_id in range(self.num_working_param_groups):
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
return True
return False
class LowLevelZeroOptimizer(ColossalaiOptimizer):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
......@@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
verbose=verbose)
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
# gradient clipping
self._clip_grad_norm = clip_grad_norm
......@@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
@property
def dtype(self):
return self._dtype
@property
def loss_scale(self):
return self.grad_scaler.scale
@property
def num_param_groups(self):
return len(self._working_param_groups)
......@@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
################################
def backward(self, loss, retain_graph=False, sync_grad=True):
loss = self.loss_scale * loss
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
loss.backward(retain_graph=retain_graph)
# finish gradient reduction
......@@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
if self.mixed_precision_mixin is not None:
self.mixed_precision_mixin.pre_zero_grad()
for _, param_group in self._working_param_groups.items():
for param in param_group:
if set_to_none:
......@@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
# check for overflow
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_average_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
......@@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# Mixed Precision Utilities #
#############################
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
# check for overflow
for group_id in range(len(self._working_param_groups)):
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
# all-reduce over model parallel group
if self._mp_torch_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
if self._found_overflow.item() > 0:
return True
else:
return False
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
div_scale = 1.0
if self.mixed_precision_mixin is not None:
div_scale = self.mixed_precision_mixin.get_grad_div_scale()
if self._clip_grad_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
combined_scale = clip * self.loss_scale
div_scale = clip * div_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
grad.data.mul_(1. / div_scale)
############################
# Gradient Synchronization #
......
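`check_local_overflow` above scans the averaged gradients for inf/NaN values; the mixin then combines the per-rank flags. A minimal single-process sketch of that test (the distributed aggregation is only noted in a comment, and the grads here are toy values):

```python
import torch

def has_inf_or_nan(t: torch.Tensor) -> bool:
    return bool(torch.isinf(t).any() or torch.isnan(t).any())

def check_local_overflow(avg_grads) -> bool:
    # in the real mixin the grads come from the GradientStore, one list per param group
    return any(g is not None and has_inf_or_nan(g) for g in avg_grads)

grads = [torch.randn(4), None, torch.tensor([1.0, float('inf')])]
assert check_local_overflow(grads)
# in the distributed setting the boolean is turned into a tensor and
# all-reduced with MAX across the data- and model-parallel groups
```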
......@@ -98,7 +98,7 @@ Lastly, if you want to skip some code, you just need to add the following annota
<!--- doc-test-ignore-end -->
```
If you have any dependency required, please add it to `requriements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda.
If you have any dependency required, please add it to `requirements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda.
### 💉 Auto Documentation
......
# References
The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few reserach works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format.
The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few research works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format.
## By Our Team
......