Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a203b709
Unverified
Commit
a203b709
authored
Sep 06, 2022
by
ver217
Committed by
GitHub
Sep 06, 2022
Browse files
[hotfix] fix init context (#1543)
* fix init context * fix lazy init ctx
parent
64169f3e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
colossalai/utils/model/lazy_init_context.py
colossalai/utils/model/lazy_init_context.py
+8
-8
colossalai/utils/model/utils.py
colossalai/utils/model/utils.py
+7
-5
No files found.
colossalai/utils/model/lazy_init_context.py
View file @
a203b709
...
@@ -15,7 +15,7 @@ class LazyInitContext():
...
@@ -15,7 +15,7 @@ class LazyInitContext():
"""
"""
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
initialization functions for lazy initialization
initialization functions for lazy initialization
Note:
Note:
This API is only experimental and subject to future changes.
This API is only experimental and subject to future changes.
...
@@ -23,17 +23,17 @@ class LazyInitContext():
...
@@ -23,17 +23,17 @@ class LazyInitContext():
with LazyInitContext() as ctx:
with LazyInitContext() as ctx:
model = nn.Linear(10, 10)
model = nn.Linear(10, 10)
model.weight.zero_()
model.weight.zero_()
# make sure the weight is a meta tensor
# make sure the weight is a meta tensor
assert model.weight.is_meta
assert model.weight.is_meta
# initialize weights
# initialize weights
ctx.lazy_init_parameters(model)
ctx.lazy_init_parameters(model)
# make sure the weight is not a meta tensor
# make sure the weight is not a meta tensor
# and initialized correctly
# and initialized correctly
assert not model.weight.is_meta and torch.all(model.weight == 0)
assert not model.weight.is_meta and torch.all(model.weight == 0)
Args:
Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
extra_torch_tensor_func (List[str]): extra torch tensor functions related
...
@@ -138,14 +138,14 @@ class LazyInitContext():
...
@@ -138,14 +138,14 @@ class LazyInitContext():
cls
.
__orig_init__
=
cls
.
__init__
cls
.
__orig_init__
=
cls
.
__init__
cls
.
__init__
=
self
.
_wrap_module_init
(
cls
.
__init__
)
cls
.
__init__
=
self
.
_wrap_module_init
(
cls
.
__init__
)
substitute_init_recursively
(
self
.
_torch_mod_cls
,
_activate_wrap_init
)
substitute_init_recursively
(
self
.
_torch_mod_cls
,
_activate_wrap_init
,
set
()
)
def _unpatch_submodule_init(self):
    """Undo the __init__ patching: put every patched module class back to
    the constructor saved in ``__orig_init__``."""

    def _restore(cls):
        # __orig_init__ was stashed by the patching pass; reinstall it.
        cls.__init__ = cls.__orig_init__

    # Fresh visited-set so each class in the hierarchy is restored exactly once.
    substitute_init_recursively(self._torch_mod_cls, _restore, set())
def
_patch_torch_tensor_funcs
(
self
):
def
_patch_torch_tensor_funcs
(
self
):
# patch tensor value-setting functions
# patch tensor value-setting functions
...
@@ -178,7 +178,7 @@ class LazyInitContext():
...
@@ -178,7 +178,7 @@ class LazyInitContext():
def
lazy_init_parameters
(
self
,
model
:
torch
.
nn
.
Module
,
device
=
'cpu'
):
def
lazy_init_parameters
(
self
,
model
:
torch
.
nn
.
Module
,
device
=
'cpu'
):
"""
"""
Initialize the weights of the meta-tensor model.
Initialize the weights of the meta-tensor model.
Args:
Args:
model (`torch.nn.Module`): the model instantiated under the context.
model (`torch.nn.Module`): the model instantiated under the context.
device (str): the device on which weights are initialized
device (str): the device on which weights are initialized
...
...
colossalai/utils/model/utils.py
View file @
a203b709
...
@@ -3,10 +3,12 @@ import functools
...
@@ -3,10 +3,12 @@ import functools
from
typing
import
Optional
from
typing
import
Optional
def substitute_init_recursively(cls, func, visited: Optional[set] = None):
    """Apply ``func`` once to every direct or indirect subclass of ``cls``.

    The subclass tree is walked depth-first; ``func`` is applied to the
    deepest subclasses first. ``visited`` guards against applying ``func``
    twice to a class that is reachable through multiple inheritance paths
    (diamond hierarchies), which would otherwise double-wrap ``__init__``.

    Args:
        cls: root class whose subclass tree is traversed. ``cls`` itself is
            NOT passed to ``func`` — only its subclasses are.
        func: callable invoked exactly once per subclass.
        visited: set of classes already processed, shared across the whole
            recursion. Defaults to a fresh empty set, so callers may omit it.
    """
    if visited is None:
        visited = set()
    for subcls in cls.__subclasses__():
        substitute_init_recursively(subcls, func, visited)
        if subcls not in visited:
            func(subcls)
            visited.add(subcls)
def
call_to_str
(
base
,
*
args
,
**
kwargs
):
def
call_to_str
(
base
,
*
args
,
**
kwargs
):
...
@@ -64,7 +66,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
...
@@ -64,7 +66,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
# Replace .__init__() for all existing subclasses of torch.nn.Module
# Replace .__init__() for all existing subclasses of torch.nn.Module
# Execute self._post_init_method after the default init function.
# Execute self._post_init_method after the default init function.
substitute_init_recursively
(
torch
.
nn
.
modules
.
module
.
Module
,
_enable_class
)
substitute_init_recursively
(
torch
.
nn
.
modules
.
module
.
Module
,
_enable_class
,
set
()
)
# holding on to the current __init__subclass__ for exit
# holding on to the current __init__subclass__ for exit
torch
.
nn
.
modules
.
module
.
Module
.
_old_init_subclass
=
(
torch
.
nn
.
modules
.
module
.
Module
.
__init_subclass__
)
torch
.
nn
.
modules
.
module
.
Module
.
_old_init_subclass
=
(
torch
.
nn
.
modules
.
module
.
Module
.
__init_subclass__
)
...
@@ -87,7 +89,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
...
@@ -87,7 +89,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
cls
.
__init__
=
cls
.
_old_init
cls
.
__init__
=
cls
.
_old_init
# Replace .__init__() for all existing subclasses of torch.nn.Module
# Replace .__init__() for all existing subclasses of torch.nn.Module
substitute_init_recursively
(
torch
.
nn
.
modules
.
module
.
Module
,
_disable_class
)
substitute_init_recursively
(
torch
.
nn
.
modules
.
module
.
Module
,
_disable_class
,
set
()
)
# Replace .__init__() for future subclasses of torch.nn.Module
# Replace .__init__() for future subclasses of torch.nn.Module
torch
.
nn
.
modules
.
module
.
Module
.
__init_subclass__
=
(
torch
.
nn
.
modules
.
module
.
Module
.
_old_init_subclass
)
torch
.
nn
.
modules
.
module
.
Module
.
__init_subclass__
=
(
torch
.
nn
.
modules
.
module
.
Module
.
_old_init_subclass
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment