OpenDAS / ColossalAI
Commit e6ec99d3 (unverified)

[utils] fixed lazy init context (#1867)

Authored Nov 10, 2022 by Frank Lee; committed via GitHub on Nov 10, 2022
Parent: 50c4cb01
Showing 2 changed files with 34 additions and 24 deletions (+34 -24):

  colossalai/utils/model/lazy_init_context.py   +17 -14
  tests/test_fx/test_complete_workflow.py       +17 -10
colossalai/utils/model/lazy_init_context.py

 #!/usr/bin/env python
 # coding: utf-8
+import inspect
+import types
+from typing import Callable, List
+
 import torch
 import torch.nn as nn
-from colossalai.tensor import ColoParameter, ColoTensor
-import types
-import inspect
-from typing import List, Callable
+
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
...
@@ -35,14 +36,15 @@ class LazyInitContext():
         assert not model.weight.is_meta and torch.all(model.weight == 0)
 
     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """
 
     tensor_set_value_func = ['zero_', 'fill_']
 
-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
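The corner case named in the new docstring follows from the interception strategy: the context hijacks nn.init functions and the listed Tensor methods (`zero_`, `fill_`), but not torch constructor functions such as torch.zeros (hence the TODO above). A minimal sketch of the kind of module that slips through; the module itself is hypothetical and only illustrates the pattern:

import torch
import torch.nn as nn

class PatchedModule(nn.Module):    # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # Built via nn.Module / nn.init: the context can intercept the init
        # call, record it, and replay it after materialization.
        self.proj = nn.Linear(4, 4)
        # Built directly via a torch constructor: torch.zeros is not hijacked,
        # so this tensor is created eagerly and cannot be captured.
        self.weight = torch.zeros(4, 4)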
...
@@ -212,18 +214,19 @@ class LazyInitContext():
             materialized_tensor = torch.empty_like(tensor, device=device)
             # if this tensor is a meta tensor, it must have an init function
             assert tensor in self._intercepted_nn_init_func_cache
-            tensor = materialized_tensor
+        else:
+            materialized_tensor = tensor
 
         # apply init function
         if tensor in self._intercepted_nn_init_func_cache:
             init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-            init_func(tensor, *args, **kwargs)
+            init_func(materialized_tensor, *args, **kwargs)
 
         # convert it to ColoTensor or ColoParameter
         if is_param:
-            tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+            tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
         else:
-            tensor = ColoTensor.from_torch_tensor(tensor)
+            tensor = ColoTensor.from_torch_tensor(materialized_tensor)
 
         # override the original tensor
         with torch.no_grad():
...
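This fix matters because _intercepted_nn_init_func_cache is keyed by the original meta tensor, and torch.Tensor hashes by object identity: once the old code rebound `tensor` to the freshly allocated `materialized_tensor`, the cache lookup below it could no longer hit, so the recorded init function was never applied to the new storage. A minimal sketch of the pattern, with illustrative names rather than the actual ColossalAI internals:

import torch

init_cache = {}    # keyed by the original meta tensor

meta_t = torch.empty(2, 2, device='meta')
init_cache[meta_t] = [(torch.nn.init.ones_, (), {})]

materialized = torch.empty_like(meta_t, device='cpu')

# torch.Tensor.__hash__ is identity-based, so a fresh tensor is never a cache hit:
assert materialized not in init_cache

# Keeping meta_t as the lookup key and applying the init to the materialized
# tensor (what the fixed code does) replays the recorded initialization:
init_func, args, kwargs = init_cache[meta_t][-1]
init_func(materialized, *args, **kwargs)
assert torch.all(materialized == 1)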
tests/test_fx/test_complete_workflow.py

-import colossalai
-from functools import partial
-import torch
-import torch.nn as nn
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
+import torch.multiprocessing as mp
+from functools import partial
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext
 
 class MLP(torch.nn.Module):
...
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)
 
+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
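For context, the new assertion relies on standard meta-tensor semantics: a meta tensor carries only shape and dtype metadata and allocates no storage, which is what makes deferred initialization cheap for large models. A quick standalone check in plain PyTorch, independent of ColossalAI:

import torch

t = torch.empty(1024, 1024, device='meta')   # no storage is allocated
assert t.is_meta

real = torch.empty_like(t, device='cpu')     # materialization: same shape/dtype, real storage
assert not real.is_meta and real.shape == t.shape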
...
@@ -46,6 +51,8 @@ def run_workflow(world_size):
     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta
 
     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
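The expected shard shape follows from splitting the first linear's [16, 16] weight along its output dimension across the tensor-parallel group. A quick sanity check of the arithmetic, illustrative only and mirroring the assertion above:

full_shape = [16, 16]    # nn.Linear weight is [out_features, in_features]
for world_size in (1, 2):
    shard = [full_shape[0] // world_size, full_shape[1]]
    print(world_size, shard)    # 1 -> [16, 16], 2 -> [8, 16]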
...
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'
 def run_dist(rank, world_size, port):
...
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):
 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
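Putting the two files together: the intended workflow matches the class docstring example, whose final assertion is visible in the first hunk above. A minimal single-process sketch along those lines, assuming the default to_meta=True behaviour introduced by this commit:

import torch
import torch.nn as nn
from colossalai.utils.model.lazy_init_context import LazyInitContext

with LazyInitContext() as ctx:
    model = nn.Linear(10, 10)    # parameters are created as meta tensors
    model.weight.zero_()         # recorded, not executed (zero_ is hijacked)

assert model.weight.is_meta

ctx.lazy_init_parameters(model)  # allocate real storage and replay recorded inits
assert not model.weight.is_meta and torch.all(model.weight == 0)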