Unverified Commit 250be4d3 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[utils] integrated colotensor with lazy init context (#1324)

* [utils] integrated colotensor with lazy init context

* polish code

* polish code

* polish code
parent 659a7407
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# coding: utf-8 # coding: utf-8
import torch import torch
from colossalai.tensor import ColoParameter import torch.nn as nn
from colossalai.tensor import ColoParameter, ColoTensor
import types import types
import inspect import inspect
import typing
from typing import List, Callable from typing import List, Callable
from colossalai.utils.model.utils import substitute_init_recursively from colossalai.utils.model.utils import substitute_init_recursively
import copy
class LazyInitContext(): class LazyInitContext():
...@@ -18,8 +18,7 @@ class LazyInitContext(): ...@@ -18,8 +18,7 @@ class LazyInitContext():
Note: Note:
This API is only experimental and subject to future changes. This API is only experimental and subject to future changes.
It should be integrated with meta tensor initialization in the future.
Usage: Usage:
with LazyInitContext() as ctx: with LazyInitContext() as ctx:
model = nn.Linear(10, 10) model = nn.Linear(10, 10)
...@@ -36,14 +35,17 @@ class LazyInitContext(): ...@@ -36,14 +35,17 @@ class LazyInitContext():
assert not model.weight.is_meta and torch.all(model.weight == 0) assert not model.weight.is_meta and torch.all(model.weight == 0)
Args: Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related extra_torch_tensor_func (List[str]): extra torch tensor functions related
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default. to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
""" """
tensor_set_value_func = ['zero_'] tensor_set_value_func = ['zero_', 'fill_']
def __init__(self, extra_torch_tensor_func: List[str] = None): def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
self._intercepted_init_func_cache = [] # TODO: hijack the torch constructor functions as well
self._to_meta = to_meta
self._intercepted_nn_init_func_cache = {}
self._nn_init_methods = self._get_nn_init_methods() self._nn_init_methods = self._get_nn_init_methods()
self._torch_mod_cls = torch.nn.modules.module.Module self._torch_mod_cls = torch.nn.modules.module.Module
...@@ -53,14 +55,20 @@ class LazyInitContext(): ...@@ -53,14 +55,20 @@ class LazyInitContext():
else: else:
self._torch_tensor_funcs = self.tensor_set_value_func self._torch_tensor_funcs = self.tensor_set_value_func
def _cache_func(self, func): @property
def to_meta(self):
return self._to_meta
def _cache_init_func(self, func):
""" """
This method wraps the ``torch.nn.init`` method so that the function call This method wraps the ``torch.nn.init`` method and torch tensor value-setting functions
is cached instead of being executed. so that the function call is cached instead of being executed.
""" """
def wrapped_init_func(*args, **kwargs): def wrapped_init_func(tensor, *args, **kwargs):
self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs)) if tensor not in self._intercepted_nn_init_func_cache:
self._intercepted_nn_init_func_cache[tensor] = []
self._intercepted_nn_init_func_cache[tensor].append((func, args, kwargs))
return wrapped_init_func return wrapped_init_func
...@@ -76,17 +84,10 @@ class LazyInitContext(): ...@@ -76,17 +84,10 @@ class LazyInitContext():
for name in nn_init_method_names: for name in nn_init_method_names:
nn_init_methods.append((name, getattr(torch.nn.init, name))) nn_init_methods.append((name, getattr(torch.nn.init, name)))
def _has_tensor_in_arg(func):
hints = typing.get_type_hints(func)
for k, v in hints.items():
if v is torch.Tensor:
return True
return False
def _is_init_method(item): def _is_init_method(item):
name, func = item name, func = item
if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
or not _has_tensor_in_arg(func)): if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')):
return False return False
else: else:
return True return True
...@@ -103,11 +104,13 @@ class LazyInitContext(): ...@@ -103,11 +104,13 @@ class LazyInitContext():
has_device = 'device' in inspect.signature(func).parameters has_device = 'device' in inspect.signature(func).parameters
def layer_lazy_init(module, *args, **kwargs): def layer_lazy_init(module, *args, **kwargs):
self._intercepted_init_func_cache.append( # if this module contains device argument
dict(func=func, module=module, args=args, kwargs=copy.deepcopy(kwargs))) # we set it to meta to initialize as meta backend
if has_device: if has_device:
kwargs['device'] = 'meta' kwargs['device'] = 'meta'
func(module, *args, **kwargs) func(module, *args, **kwargs)
# if device is not found, we initialize it and convert to meta
if not has_device: if not has_device:
module.to('meta') module.to('meta')
...@@ -122,7 +125,7 @@ class LazyInitContext(): ...@@ -122,7 +125,7 @@ class LazyInitContext():
def _patch_nn_init_funcs(self): def _patch_nn_init_funcs(self):
# patch nn.init functions # patch nn.init functions
for name, func in self._nn_init_methods: for name, func in self._nn_init_methods:
setattr(torch.nn.init, name, self._cache_func(func)) setattr(torch.nn.init, name, self._cache_init_func(func))
def _unpatch_nn_init_funcs(self): def _unpatch_nn_init_funcs(self):
# unpatch nn.init functions # unpatch nn.init functions
...@@ -150,7 +153,7 @@ class LazyInitContext(): ...@@ -150,7 +153,7 @@ class LazyInitContext():
origin_func_name = self._get_tmp_origin_func_ref(func_name) origin_func_name = self._get_tmp_origin_func_ref(func_name)
origin_func = getattr(torch.Tensor, func_name) origin_func = getattr(torch.Tensor, func_name)
setattr(torch.Tensor, origin_func_name, origin_func) setattr(torch.Tensor, origin_func_name, origin_func)
setattr(torch.Tensor, func_name, self._cache_func(origin_func)) setattr(torch.Tensor, func_name, self._cache_init_func(origin_func))
def _unpatch_torch_tensor_funcs(self): def _unpatch_torch_tensor_funcs(self):
for func_name in self._torch_tensor_funcs: for func_name in self._torch_tensor_funcs:
...@@ -159,17 +162,18 @@ class LazyInitContext(): ...@@ -159,17 +162,18 @@ class LazyInitContext():
setattr(torch.Tensor, func_name, origin_func) setattr(torch.Tensor, func_name, origin_func)
def __enter__(self): def __enter__(self):
self._patch_submodule_init() self._patch_torch_tensor_funcs()
self._patch_nn_init_funcs()
if self._to_meta:
self._patch_submodule_init()
return self return self
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
self._unpatch_submodule_init() if self._to_meta:
# build model_rebuild_dict in reverse order to make sure get correct init func for inherited class. self._unpatch_submodule_init()
self.module_rebuild_dict = {} self._unpatch_nn_init_funcs()
self._intercepted_init_func_cache.reverse() self._unpatch_torch_tensor_funcs()
for cache in self._intercepted_init_func_cache:
self.module_rebuild_dict[cache['module']] = (cache['func'], cache['args'], cache['kwargs'])
self._intercepted_init_func_cache.reverse()
def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None): def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
""" """
...@@ -178,80 +182,56 @@ class LazyInitContext(): ...@@ -178,80 +182,56 @@ class LazyInitContext():
Args: Args:
model (`torch.nn.Module`): the model instantiated under the context. model (`torch.nn.Module`): the model instantiated under the context.
device (str): the device on which weights are initialized device (str): the device on which weights are initialized
""" """
# build param mapping
param_id_to_name = dict() def _init_recursively(module: nn.Module):
for name, param in model.named_parameters(): # recursively initialize the module
param_id_to_name[id(param)] = name for mod in module.children():
for name, buffer in model.named_buffers(): _init_recursively(mod)
param_id_to_name[id(buffer)] = name
# initialize and shard tensors directly attached to the current module
assert model in self.module_rebuild_dict, 'We only support rebuild modules which intercepted during initializing by us.' for name, param in module.named_parameters(recurse=False):
_init_and_shard(module, name, param)
def _process_arg(arg):
""" for name, buf in module.named_buffers(recurse=False):
Process args recursively. If arg is a torch.nn.Module instance in module_rebuild_dict, _init_and_shard(module, name, buf)
we need to rebuild it with real parameters. If arg is a tuple or list, we will process
the element of arg with this function again. @torch.no_grad()
""" def _init_and_shard(module, name, tensor):
if torch.is_tensor(arg): # check whether the tensor is a buffer or parameter
tensor_id = id(arg) is_param = isinstance(tensor, nn.parameter.Parameter)
if tensor_id in param_id_to_name:
arg = _replace_meta_param_with_real_param(arg) # get sharding spec
dist_spec = getattr(tensor, 'dist_spec', None)
elif isinstance(arg, torch.nn.Module): pg = getattr(tensor, 'pg', None)
if arg in self.module_rebuild_dict:
arg = self.lazy_init_parameters(model=arg, device=device, call_back=call_back) # convert the tensor from meta to materialized one
if tensor.is_meta:
elif isinstance(arg, (tuple, list)): materialized_tensor = torch.empty_like(tensor, device=device)
rst_list = [] # if this tensor is a meta tensor, it must have an init function
for element in arg: assert tensor in self._intercepted_nn_init_func_cache
processed_element = _process_arg(element) tensor = materialized_tensor
rst_list.append(processed_element)
arg = rst_list # apply init function
return arg if tensor in self._intercepted_nn_init_func_cache:
init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
def _replace_meta_param_with_real_param(meta_param): init_func(tensor, *args, **kwargs)
if meta_param.device != 'meta':
return meta_param # convert it to ColoTensor or ColoParameter
tensor_id = id(meta_param) if is_param:
param_full_name = param_id_to_name[tensor_id] tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)
if '.' in param_full_name:
submodule_name, param_name = param_full_name.rsplit('.', 1)
submodule = model.get_submodule(submodule_name)
else:
submodule = model
param_name = param_full_name
setattr(submodule, param_name, real_param)
# execute call_back function on the materialized tensor
# this is where sharding can come in
if call_back:
call_back(real_param)
return real_param
func, args, kwargs = self.module_rebuild_dict[model]
args = list(args)
# check args for parameter replacement
for idx, arg in enumerate(args):
arg = _process_arg(arg)
args[idx] = arg
# check kwargs for parameter replacement
for arg_name, arg in kwargs.items():
if arg_name == 'device':
arg = device
else: else:
arg = _process_arg(arg) tensor = ColoTensor.from_torch_tensor(tensor)
kwargs[arg_name] = arg
# apply sharding
if dist_spec:
tensor = tensor.redistribute(dist_spec=dist_spec, pg=pg)
# override the original tensor
with torch.no_grad():
setattr(module, name, tensor)
# build user specified model _init_recursively(model)
with torch.no_grad():
func(model, *args, **kwargs)
return model return model
...@@ -10,27 +10,42 @@ np.random.seed(MANUAL_SEED) ...@@ -10,27 +10,42 @@ np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED) torch.manual_seed(MANUAL_SEED)
def test_lazy_init(): def test_lazy_init_with_meta():
cpu_rng_state = torch.get_rng_state() ctx = LazyInitContext(to_meta=True)
origin_model = resnet34(num_classes=10)
origin_param_dict = dict(origin_model.named_parameters())
torch.set_rng_state(cpu_rng_state)
ctx = LazyInitContext()
with ctx: with ctx:
model = resnet34(num_classes=10) model = resnet34(num_classes=10)
for param in model.parameters(): for param in model.parameters():
assert param.is_meta assert param.is_meta
for buffer in model.buffers(): for buffer in model.buffers():
assert buffer.is_meta assert buffer.is_meta
ctx.lazy_init_parameters(model) ctx.lazy_init_parameters(model)
for name, param in model.named_parameters():
assert not param.is_meta, name
for buffer in model.buffers():
assert not buffer.is_meta
def test_lazy_init_without_meta():
ctx = LazyInitContext(to_meta=False)
with ctx:
model = resnet34(num_classes=10)
for param in model.parameters(): for param in model.parameters():
assert not param.is_meta assert not param.is_meta
for buffer in model.buffers(): for buffer in model.buffers():
assert not buffer.is_meta assert not buffer.is_meta
param_dict = dict(model.named_parameters())
for key in origin_param_dict.keys(): conv1_weight_before_init = model.conv1.weight.clone()
assert origin_param_dict[key].data.equal(param_dict[key].data) ctx.lazy_init_parameters(model)
conv1_weight_after_init = model.conv1.weight.clone()
assert not torch.allclose(conv1_weight_after_init, conv1_weight_before_init)
if __name__ == '__main__': if __name__ == '__main__':
test_lazy_init() test_lazy_init_with_meta()
test_lazy_init_without_meta()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment