Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e6ec99d3
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "527758b2aef56436556b8df79e32a88fd7f01e9b"
Unverified
Commit
e6ec99d3
authored
Nov 10, 2022
by
Frank Lee
Committed by
GitHub
Nov 10, 2022
Browse files
[utils] fixed lazy init context (#1867)
parent
50c4cb01
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
24 deletions
+34
-24
colossalai/utils/model/lazy_init_context.py
colossalai/utils/model/lazy_init_context.py
+17
-14
tests/test_fx/test_complete_workflow.py
tests/test_fx/test_complete_workflow.py
+17
-10
No files found.
colossalai/utils/model/lazy_init_context.py
View file @
e6ec99d3
#!/usr/bin/env python
# coding: utf-8
import
inspect
import
types
from
typing
import
Callable
,
List
import
torch
import
torch.nn
as
nn
from
colossalai.tensor
import
ColoParameter
,
ColoTensor
import
types
import
inspect
from
typing
import
List
,
Callable
from
colossalai.tensor
import
ColoParameter
,
ColoTensor
from
colossalai.utils.model.utils
import
substitute_init_recursively
class
LazyInitContext
():
"""
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
A context to allow for lazy weight initialization of PyTorch modules. It intercepts the tensor
initialization functions for lazy initialization
Note:
This API is only experimental and subject to future changes.
This API is only experimental and subject to future changes.
Usage:
with LazyInitContext() as ctx:
...
...
@@ -30,19 +31,20 @@ class LazyInitContext():
# initialize weights
ctx.lazy_init_parameters(model)
# make sure the weight is not a meta tensor
# make sure the weight is not a meta tensor
# and initialized correctly
assert not model.weight.is_meta and torch.all(model.weight == 0)
Args:
to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
extra_torch_tensor_func (List[str]): extra torch tensor functions related
to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
"""
tensor_set_value_func
=
[
'zero_'
,
'fill_'
]
def
__init__
(
self
,
to_meta
:
bool
=
Fals
e
,
extra_torch_tensor_func
:
List
[
str
]
=
None
):
def
__init__
(
self
,
to_meta
:
bool
=
Tru
e
,
extra_torch_tensor_func
:
List
[
str
]
=
None
):
# TODO: hijack the torch constructor functions as well
self
.
_to_meta
=
to_meta
self
.
_intercepted_nn_init_func_cache
=
{}
...
...
@@ -212,18 +214,19 @@ class LazyInitContext():
materialized_tensor
=
torch
.
empty_like
(
tensor
,
device
=
device
)
# if this tensor is a meta tensor, it must have an init function
assert
tensor
in
self
.
_intercepted_nn_init_func_cache
tensor
=
materialized_tensor
else
:
materialized_tensor
=
tensor
# apply init function
if
tensor
in
self
.
_intercepted_nn_init_func_cache
:
init_func
,
args
,
kwargs
=
self
.
_intercepted_nn_init_func_cache
[
tensor
][
-
1
]
init_func
(
tensor
,
*
args
,
**
kwargs
)
init_func
(
materialized_
tensor
,
*
args
,
**
kwargs
)
# convert it to ColoTensor or ColoParameter
if
is_param
:
tensor
=
ColoParameter
.
from_torch_tensor
(
tensor
,
requires_grad
=
tensor
.
requires_grad
)
tensor
=
ColoParameter
.
from_torch_tensor
(
materialized_
tensor
,
requires_grad
=
tensor
.
requires_grad
)
else
:
tensor
=
ColoTensor
.
from_torch_tensor
(
tensor
)
tensor
=
ColoTensor
.
from_torch_tensor
(
materialized_
tensor
)
# override the original tensor
with
torch
.
no_grad
():
...
...
tests/test_fx/test_complete_workflow.py
View file @
e6ec99d3
import
colossalai
import
torch
import
torch.nn
as
nn
from
functools
import
partial
import
pytest
import
torch
.multiprocessing
as
mp
import
torch
import
torch.distributed
as
dist
from
colossalai.testing
import
rerun_if_address_is_in_use
from
functools
import
partial
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
colossalai
from
colossalai.fx
import
ColoTracer
from
colossalai.utils.model.lazy_init_context
import
LazyInitContext
from
colossalai.fx.passes.shard_1d_pass
import
transformer_mlp_pass
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
ProcessGroup
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils.model.lazy_init_context
import
LazyInitContext
class
MLP
(
torch
.
nn
.
Module
):
...
...
@@ -35,6 +37,9 @@ def run_workflow(world_size):
with
LazyInitContext
()
as
ctx
:
model
=
MLP
(
16
)
for
param
in
model
.
parameters
():
assert
param
.
is_meta
# tracing
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
model
)
...
...
@@ -46,6 +51,8 @@ def run_workflow(world_size):
# materialization and sharding
ctx
.
lazy_init_parameters
(
annotated_gm
)
for
param
in
model
.
parameters
():
assert
not
param
.
is_meta
# # check sharding
assert
list
(
model
.
linear1
.
weight
.
shape
)
==
[
16
//
world_size
,
16
]
...
...
@@ -57,7 +64,7 @@ def run_workflow(world_size):
data
=
torch
.
rand
(
4
,
16
)
non_fx_out
=
model
(
data
)
fx_out
=
annotated_gm
(
data
)
assert
torch
.
equal
(
non_fx_out
,
fx_out
)
assert
torch
.
equal
(
non_fx_out
,
fx_out
)
,
f
'
{
non_fx_out
}
vs
{
fx_out
}
'
def
run_dist
(
rank
,
world_size
,
port
):
...
...
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
if
__name__
==
'__main__'
:
test_complete_workflow
(
2
)
test_complete_workflow
(
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment