OpenDAS / ColossalAI

Commit e43f83aa (unverified)
Authored Apr 26, 2022 by Jiarui Fang, committed via GitHub on Apr 26, 2022

[Tensor] get named parameters for model using ColoTensors (#874)

Parent: 28830402

Showing 3 changed files with 59 additions and 3 deletions:
colossalai/tensor/__init__.py       +5  -3
colossalai/tensor/utils.py          +51 -0
tests/test_tensor/test_net_tp.py    +3  -0
colossalai/tensor/__init__.py
@@ -2,8 +2,10 @@ from .spec import ComputePattern, ParallelAction, TensorSpec
 from .op_wrapper import (colo_op_impl,)
 from .colo_tensor import ColoTensor
-from .utils import convert_parameter
+from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *

-__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction']
+__all__ = ['ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor']
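With this hunk the helper becomes part of the package's public surface, so callers can import it from colossalai.tensor directly instead of reaching into colossalai.tensor.utils. A quick sanity check (a sketch assuming an environment built from this commit):

import colossalai.tensor as ct
from colossalai.tensor import convert_parameter, named_params_with_colotensor

# Both the pre-existing and the newly exported names are listed in __all__.
assert 'convert_parameter' in ct.__all__
assert 'named_params_with_colotensor' in ct.__all__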
colossalai/tensor/utils.py
@@ -2,6 +2,57 @@ import torch
 from colossalai.tensor.colo_tensor import ColoTensor
+from typing import Iterator, Tuple, Union
+import torch.nn as nn
+from colossalai.tensor import ColoTensor
+
+
+# The function is credited to the PyTorch Team
+def named_params_with_colotensor(
+    module: nn.Module,
+    prefix: str = '',
+    recurse: bool = True,
+) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+    r"""Returns an iterator over module parameters (together with the
+    ColoTensor parameters), yielding both the name of the parameter
+    as well as the parameter itself. This is typically passed to a
+    :class:`torchshard._shard.sharded_optim.ShardedOptimizer`.
+
+    Args:
+        prefix (str): prefix to prepend to all parameter names.
+        recurse (bool): if True, then yields parameters of this module
+            and all submodules. Otherwise, yields only parameters that
+            are direct members of this module.
+
+    Yields:
+        (string, Union[Tensor, ColoTensor]): Tuple containing
+            the name and the parameter (or ColoTensor parameter).
+
+    Example::
+
+        >>> model = torch.nn.Linear(*linear_size)
+        >>> delattr(model, 'weight')
+        >>> setattr(model, 'weight', ColoTensor(...))
+        >>> for name, param in named_params_with_colotensor(model):
+        >>>     if name in ['weight']:
+        >>>         print(param.size())
+    """
+    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+    memo = set()
+    for mod_prefix, mod in modules:
+        # find all ColoTensor params stored directly in the module __dict__
+        for name, val in vars(mod).items():
+            if isinstance(val, ColoTensor) and val not in memo:
+                memo.add(val)
+                name = mod_prefix + ('.' if mod_prefix else '') + name
+                yield name, val
+
+    # find all remaining nn.Parameters
+    for name, val in module.named_parameters():
+        yield name, val
+
+
 def _convert_tensor(tensor: torch.Tensor) -> ColoTensor:
     return ColoTensor(tensor)
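To make the traversal concrete, here is a short, self-contained sketch of what the new helper yields. The two-layer model is hypothetical (not part of this commit), and the ColoTensor(weight) construction simply mirrors _convert_tensor above; in practice ColoInitContext performs this replacement for you:

import torch
import torch.nn as nn
from colossalai.tensor import ColoTensor, named_params_with_colotensor

# Hypothetical toy model; any nn.Module works.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))

# Replace one weight with a ColoTensor. Deleting the nn.Parameter first
# ensures the new value lands in the instance __dict__, which is exactly
# where named_params_with_colotensor looks via vars(mod).
weight = model[0].weight.detach()
delattr(model[0], 'weight')
setattr(model[0], 'weight', ColoTensor(weight))

for name, param in named_params_with_colotensor(model):
    print(name, type(param).__name__)

# Expected output: '0.weight' as a ColoTensor found through vars(mod),
# then '0.bias', '1.weight' and '1.bias' from module.named_parameters().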
tests/test_tensor/test_net_tp.py
@@ -7,6 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
+from colossalai.tensor import named_params_with_colotensor
 from functools import partial
@@ -19,6 +20,8 @@ def run_simple_net():
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)

+    for param in named_params_with_colotensor(model):
+        print(param)

     # we set the Specs for weight of each linear.
     # model.proj1.weight.set_spec('1Drow')
     # model.proj2.weight.set_spec('1Drow')
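Note that named_params_with_colotensor yields (name, value) pairs, so the loop in the test prints 2-tuples. An unpacked variant (illustrative only, not part of the commit) makes the output easier to read:

for name, param in named_params_with_colotensor(model):
    print(name, param)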