OpenDAS / ColossalAI / Commits

Commit bad5d4c0 (unverified), authored Jun 10, 2022 by Frank Lee, committed by GitHub on Jun 10, 2022

[context] support lazy init of module (#1088)

* [context] support lazy init of module
* polish code

Parent: be01db37
Showing 2 changed files with 246 additions and 0 deletions

colossalai/utils/model/lazy_init_context.py    +223 −0
tests/test_utils/test_lazy_init_ctx.py         +23 −0
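For context, the change builds on PyTorch's meta device: a module constructed with device='meta' allocates no real storage, and its parameters can be materialized later on a concrete device. Below is a minimal sketch of that underlying mechanism in plain PyTorch (it is not code from this commit, and it assumes a PyTorch version where module constructors accept a device keyword):

import torch
import torch.nn as nn

# Constructing a layer on the 'meta' device allocates no real storage.
layer = nn.Linear(10, 10, device='meta')
assert layer.weight.is_meta

# Materialize the weight later on a concrete device; its values are
# undefined until an init function is applied to it.
real_weight = torch.empty_like(layer.weight, device='cpu')
torch.nn.init.zeros_(real_weight)
assert not real_weight.is_meta and torch.all(real_weight == 0)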
colossalai/utils/model/lazy_init_context.py  0 → 100644
#!/usr/bin/env python
# coding: utf-8

import torch
from colossalai.tensor import ColoParameter
import types
import inspect
import typing
from typing import List, Callable


class LazyInitContext():
    """
    A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
    initialization functions for lazy initialization.

    Note:
        This API is only experimental and subject to future changes.
        It should be integrated with meta tensor initialization in the future.

    Usage:
        with LazyInitContext() as ctx:
            model = nn.Linear(10, 10)
            model.weight.zero_()

        # make sure the weight is a meta tensor
        assert model.weight.is_meta

        # initialize weights
        ctx.lazy_init_parameters(model)

        # make sure the weight is not a meta tensor
        # and initialized correctly
        assert not model.weight.is_meta and torch.all(model.weight == 0)

    Args:
        extra_torch_tensor_func (List[str]): extra torch tensor functions related
            to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
    """

    tensor_set_value_func = ['zero_']

    def __init__(self, extra_torch_tensor_func: List[str] = None):
        self._intercepted_init_func_cache = []
        self._nn_init_methods = self._get_nn_init_methods()
        self._torch_mod_cls = torch.nn.modules.module.Module

        if extra_torch_tensor_func:
            # use tuple to remove duplicates
            self._torch_tensor_funcs = tuple(self.tensor_set_value_func + extra_torch_tensor_func)
        else:
            self._torch_tensor_funcs = self.tensor_set_value_func

    def _cache_func(self, func):
        """
        This method wraps the ``torch.nn.init`` method so that the function call
        is cached instead of being executed.
        """

        def wrapped_init_func(*args, **kwargs):
            self._intercepted_init_func_cache.append(dict(func=func, args=args, kwargs=kwargs))

        return wrapped_init_func

    def _get_nn_init_methods(self):
        """
        This method looks for all available functions in the ``torch.nn.init``
        module.
        """
        nn_init_method_names = dir(torch.nn.init)
        nn_init_methods = []

        # look for all methods in the ``torch.nn.init`` module
        for name in nn_init_method_names:
            nn_init_methods.append((name, getattr(torch.nn.init, name)))

        def _has_tensor_in_arg(func):
            # check whether any type-annotated argument of ``func`` is a torch.Tensor
            hints = typing.get_type_hints(func)
            for k, v in hints.items():
                if v is torch.Tensor:
                    return True
            return False

        def _is_init_method(item):
            name, func = item

            if (not isinstance(func, types.FunctionType) or name.startswith('_') or not name.endswith('_')
                    or not _has_tensor_in_arg(func)):
                return False
            else:
                return True

        # remove methods which are not init functions
        nn_init_methods = list(filter(_is_init_method, nn_init_methods))
        return nn_init_methods

    def _wrap_module_init(self, func):
        """
        This method wraps the calls to the `__init__` of ``torch.nn.Module`` and replaces
        the argument device with value 'meta' so that all modules are created as meta tensors.
        """
        has_device = 'device' in inspect.signature(func).parameters

        def layer_lazy_init(*args, **kwargs):
            if has_device:
                kwargs['device'] = 'meta'
            func(*args, **kwargs)

        return layer_lazy_init

    def _get_tmp_origin_func_ref(self, name):
        """
        Generate a function name for consistency during caching and retrieving.
        """
        return f'_orig_{name}'

    def _patch_nn_init_funcs(self):
        # patch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, self._cache_func(func))

    def _unpatch_nn_init_funcs(self):
        # unpatch nn.init functions
        for name, func in self._nn_init_methods:
            setattr(torch.nn.init, name, func)

    def _patch_submodule_init(self):
        # patch classes' __init__ methods
        for sub_cls in self._torch_mod_cls.__subclasses__():
            sub_cls.__orig_init__ = sub_cls.__init__
            sub_cls.__init__ = self._wrap_module_init(sub_cls.__init__)

    def _unpatch_submodule_init(self):
        for sub_cls in self._torch_mod_cls.__subclasses__():
            sub_cls.__init__ = sub_cls.__orig_init__

    def _patch_torch_tensor_funcs(self):
        # patch tensor value-setting functions
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, func_name)
            setattr(torch.Tensor, origin_func_name, origin_func)
            setattr(torch.Tensor, func_name, self._cache_func(origin_func))

    def _unpatch_torch_tensor_funcs(self):
        for func_name in self._torch_tensor_funcs:
            origin_func_name = self._get_tmp_origin_func_ref(func_name)
            origin_func = getattr(torch.Tensor, origin_func_name)
            setattr(torch.Tensor, func_name, origin_func)

    def __enter__(self):
        self._patch_nn_init_funcs()
        self._patch_torch_tensor_funcs()
        self._patch_submodule_init()
        return self

    def __exit__(self, *args, **kwargs):
        self._unpatch_submodule_init()
        self._unpatch_torch_tensor_funcs()
        self._unpatch_nn_init_funcs()

    def lazy_init_parameters(self, model: torch.nn.Module, device='cpu', call_back: Callable = None):
        """
        Initialize the weights of the meta-tensor model.

        Args:
            model (`torch.nn.Module`): the model instantiated under the context.
            device (str): the device on which weights are initialized.
            call_back (Callable): an optional function invoked on each parameter after it is materialized.
        """
        # build param mapping
        param_id_to_name = dict()
        for name, param in model.named_parameters():
            param_id_to_name[id(param)] = name

        def _replace_meta_param_with_real_param(meta_param):
            tensor_id = id(meta_param)
            param_full_name = param_id_to_name[tensor_id]
            real_param = torch.empty_like(meta_param, dtype=meta_param.dtype, device=device)
            real_param = ColoParameter(real_param, requires_grad=meta_param.requires_grad)

            if '.' in param_full_name:
                submodule_name, param_name = param_full_name.rsplit('.', 1)
                submodule = model.get_submodule(submodule_name)
            else:
                submodule = model
                param_name = param_full_name
            setattr(submodule, param_name, real_param)

            # execute the call_back function on the materialized tensor
            # this is where sharding can come in
            if call_back:
                call_back(real_param)
            return real_param

        # build modules
        for cache in self._intercepted_init_func_cache:
            func = cache['func']
            args = list(cache['args'])
            kwargs = cache['kwargs']

            # check args for parameter replacement
            for idx, arg in enumerate(args):
                if torch.is_tensor(arg):
                    tensor_id = id(arg)

                    if tensor_id not in param_id_to_name:
                        continue
                    else:
                        arg = _replace_meta_param_with_real_param(arg)
                        args[idx] = arg

            # check kwargs for parameter replacement
            for arg_name, arg in kwargs.items():
                if torch.is_tensor(arg):
                    tensor_id = id(arg)

                    if tensor_id not in param_id_to_name:
                        continue
                    else:
                        arg = _replace_meta_param_with_real_param(arg)
                        kwargs[arg_name] = arg

            with torch.no_grad():
                func(*args, **kwargs)
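The call_back argument of lazy_init_parameters is invoked on every parameter right after it is materialized, which is where sharding logic could hook in. A hedged usage sketch follows; the model and the report_param hook are illustrative, not part of this commit, and assume colossalai is installed so that the module above is importable:

import torch.nn as nn
from colossalai.utils.model.lazy_init_context import LazyInitContext

# Hypothetical hook: in practice this could shard or otherwise transform
# each parameter as it is materialized.
def report_param(param):
    print(param.shape, param.device)

with LazyInitContext() as ctx:
    # module construction is intercepted, so parameters are meta tensors
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

assert all(p.is_meta for p in model.parameters())

# materialize the parameters on CPU and run the hook on each one
ctx.lazy_init_parameters(model, device='cpu', call_back=report_param)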
tests/test_utils/test_lazy_init_ctx.py  0 → 100644
import torch
import torch.nn as nn
from colossalai.utils.model.lazy_init_context import LazyInitContext


def test_lazy_init_ctx():
    with LazyInitContext() as ctx:
        model = nn.Linear(10, 10)
        model.weight.zero_()

    # make sure the weight is a meta tensor
    assert model.weight.is_meta

    # initialize weights
    ctx.lazy_init_parameters(model)

    # make sure the weight is not a meta tensor
    # and initialized correctly
    assert not model.weight.is_meta and torch.all(model.weight == 0)


if __name__ == '__main__':
    test_lazy_init_ctx()
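A similar sketch for the extra_torch_tensor_func argument, which defers additional in-place tensor functions besides the default zero_. The fill_ example below is illustrative only and mirrors the pattern of the test above; it assumes fill_ behaves like zero_ under the same caching and replay mechanism:

import torch
import torch.nn as nn
from colossalai.utils.model.lazy_init_context import LazyInitContext

# Defer 'fill_' calls in addition to the default 'zero_'.
with LazyInitContext(extra_torch_tensor_func=['fill_']) as ctx:
    model = nn.Linear(4, 4)
    model.weight.fill_(1.0)

# the fill_ call was cached rather than executed, and the weight is still a meta tensor
assert model.weight.is_meta

# replaying the cached calls materializes the weight and applies fill_
ctx.lazy_init_parameters(model)
assert not model.weight.is_meta and torch.all(model.weight == 1)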