OpenDAS / ColossalAI · Commits

Commit dfc88b85 (unverified)
Authored May 11, 2022 by Ziyue Jiang; committed by GitHub on May 11, 2022

[Tensor] simplify named param (#928)

* simplify ColoModulize
* simplify ColoModulize
* polish
* polish

Parent: 32a45cd7
Showing 2 changed files with 14 additions and 35 deletions (+14 −35):

    colossalai/utils/model/colo_init_context.py   +4  −33
    tests/test_tensor/test_model.py               +10 −2
colossalai/utils/model/colo_init_context.py
@@ -90,56 +90,28 @@ def ColoModulize(module):
     Replacing the parameters() and named_parameters() with our customized ones
     """

-    def named_params_with_colotensor(
-            module: nn.Module,
-            prefix: str = '',
-            recurse: bool = True,
-    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
-        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
-
-        memo = set()
-        for mod_prefix, mod in modules:
-            # find all colotensors tensor params
-            for name, val in vars(mod).items():
-                if isinstance(val, ColoTensor) and val not in memo:
-                    memo.add(val)
-                    name = mod_prefix + ('.' if mod_prefix else '') + name
-                    yield name, val
-
-        # find all nn.Parameters
-        for name, val in module.old_named_parameters(recurse=recurse):
-            yield name, val
-
     def fake_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for p in module.old_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield p

     def fake_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for name, p in module.old_named_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield name, p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield name, p

-    def colo_parameters(self, *args, **kargs):
-        for _, p in named_params_with_colotensor(self, *args, **kargs):
-            yield p
-
-    def colo_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
-            yield name, p
-
     module.old_named_parameters = module.named_parameters
     module.old_parameters = module.parameters

     funcType = types.MethodType
     module.parameters = funcType(fake_parameters, module)
     module.named_parameters = funcType(fake_named_parameters, module)
-    module.colo_parameters = funcType(colo_parameters, module)
-    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module.colo_parameters = module.old_parameters
+    module.colo_named_parameters = module.old_named_parameters
     module._colo_visited = True


 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
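The hunk above keeps one pattern and deletes the rest: stash the module's original bound parameters()/named_parameters() under old_* names, then rebind filtered replacements onto the instance with types.MethodType; after this commit the colo_* views are just aliases of the saved originals. Below is a minimal, self-contained sketch of that rebinding pattern; WrappedTensor is a hypothetical stand-in for ColoTensor so the snippet runs with plain PyTorch.

import types

import torch
import torch.nn as nn


class WrappedTensor(nn.Parameter):
    # Hypothetical stand-in for ColoTensor: exposes the plain tensor payload.
    def torch_tensor(self):
        return self.data


def modulize(module: nn.Module):
    # Stash the original bound methods, as ColoModulize does above.
    module.old_named_parameters = module.named_parameters
    module.old_parameters = module.parameters

    def fake_parameters(self, *args, **kwargs):
        # Yield plain tensors: unwrap WrappedTensor, pass torch.Tensor through.
        for p in module.old_parameters(*args, **kwargs):
            if isinstance(p, WrappedTensor):
                yield p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield p

    # Rebind on this instance only; other modules keep the stock method.
    module.parameters = types.MethodType(fake_parameters, module)
    # After this commit, the colo_* views are aliases of the originals.
    module.colo_parameters = module.old_parameters
    module.colo_named_parameters = module.old_named_parameters


net = nn.Linear(2, 2)
modulize(net)
print(all(isinstance(p, torch.Tensor) for p in net.parameters()))  # True
print(len(list(net.colo_parameters())))                            # 2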
@@ -154,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
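The second hunk relies on a different mechanism: while ColoInitContext is active, torch.nn.Module.__setattr__ and register_parameter are patched at class level, so every attribute assignment on every module flows through ColossalAI's hooks (the TODO notes the patch is not yet restored on exit). A rough sketch of that class-level patching pattern, with a hypothetical logging hook standing in for _setattr_with_colotensor:

import torch.nn as nn

# Keep a handle on the stock implementation so it can be restored later.
_old_setattr = nn.Module.__setattr__


def _logging_setattr(self, name, value):
    # Hypothetical stand-in for _setattr_with_colotensor: observe every
    # parameter assignment, then defer to the original behaviour.
    if isinstance(value, nn.Parameter):
        print(f'intercepted parameter assignment: {name}')
    _old_setattr(self, name, value)


class PatchingContext:
    def __enter__(self):
        # Class-level patch: affects every nn.Module, not a single instance.
        nn.Module.__setattr__ = _logging_setattr
        return self

    def __exit__(self, *exc):
        # Restore the original, which is what the diff's TODO suggests
        # ColoInitContext should eventually do on exit.
        nn.Module.__setattr__ = _old_setattr


with PatchingContext():
    layer = nn.Linear(4, 4)  # prints for 'weight' and 'bias'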
tests/test_tensor/test_model.py

+from colossalai.tensor.colo_parameter import ColoParameter
 from tests.components_to_test.registry import non_distributed_component_funcs
 import colossalai
@@ -371,7 +372,13 @@ def _run_pretrain_load():
     dict_col = {}
     for name, param in model_pretrained.named_parameters():
         dict_pretrained[name] = param
-    for name, param in model.named_parameters():
+    c1 = 0
+    c2 = 0
+    for name, param in model.colo_named_parameters():
+        if isinstance(param, ColoParameter):
+            c1 = c1 + 1
+        else:
+            c2 = c2 + 1
         dict_col[name] = param

     for name, param in dict_pretrained.items():
@@ -416,4 +423,5 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    _test_pretrain_load(4)
+    # _test_pretrain_load(4)
+    _run_pretrain_load()
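For reference, the counting idiom the test now uses can be packaged as a small helper. This is a sketch assuming a ColossalAI checkout at this commit, where a model built under ColoInitContext exposes colo_named_parameters():

from colossalai.tensor.colo_parameter import ColoParameter


def count_param_kinds(model):
    # Mirror the test above: tally ColoParameter instances vs. plain params.
    n_colo, n_plain = 0, 0
    for _name, param in model.colo_named_parameters():
        if isinstance(param, ColoParameter):
            n_colo += 1
        else:
            n_plain += 1
    return n_colo, n_plain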