OpenDAS / ColossalAI · Commits

Unverified commit 26c49639, authored Apr 27, 2022 by Jiarui Fang, committed by GitHub on Apr 27, 2022.

[Tensor] overriding paramters() for Module using ColoTensor (#889)

Parent: daf59ff7
Showing 3 changed files with 74 additions and 6 deletions:

    +6  -1   colossalai/tensor/colo_tensor.py
    +67 -4   colossalai/utils/model/colo_init_context.py
    +1  -1   tests/test_tensor/test_model.py
colossalai/tensor/colo_tensor.py

@@ -165,7 +165,12 @@ class ColoTensor(object):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)

     def __add__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        if isinstance(o, ColoTensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        elif isinstance(o, torch.Tensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
+        else:
+            raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')

     def __truediv__(self, o) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
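With this change, __add__ dispatches on the right operand's type: adding two ColoTensors unwraps both payloads, adding a raw torch.Tensor unwraps only the left side, and any other operand raises TypeError. A minimal usage sketch of the new behavior (the shapes and values here are illustrative, not from the commit):

    import torch
    from colossalai.tensor import ColoTensor

    a = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))
    b = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))

    c = a + b                  # ColoTensor + ColoTensor -> ColoTensor
    d = a + torch.ones(2, 2)   # ColoTensor + torch.Tensor -> ColoTensor (new in this commit)

    try:
        a + 1.0                # any other right operand is rejected
    except TypeError as e:
        print(e)               # <class 'float'> is not supported in ColoTensor __add__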
colossalai/utils/model/colo_init_context.py

from colossalai.utils.cuda import get_current_device
from .utils import InsertPostInitMethodToModuleSubClasses
import torch
# from colossalai.logging import get_dist_logger
from colossalai.tensor import ColoTensor
import types
# _orig_torch_empty = torch.empty
from torch import nn
from typing import Iterator, Tuple, Union


def ColoModulize(module):
    """
    Replacing the parameters() and named_parameters() with our customized ones
    """

    def named_params_with_colotensor(
        module: nn.Module,
        prefix: str = '',
        recurse: bool = True,
    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]

        memo = set()
        for mod_prefix, mod in modules:
            # find all colotensors tensor params
            for name, val in vars(mod).items():
                if isinstance(val, ColoTensor) and val not in memo:
                    memo.add(val)
                    name = mod_prefix + ('.' if mod_prefix else '') + name
                    yield name, val

        # find all nn.Parameters
        for name, val in module.old_named_parameters(recurse=recurse):
            yield name, val

    def fake_parameters(self, *args, **kargs):
        for name, p in named_params_with_colotensor(self, *args, **kargs):
            if isinstance(p, ColoTensor):
                yield p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield p

    def fake_named_parameters(self, *args, **kargs):
        for name, p in named_params_with_colotensor(self, *args, **kargs):
            if isinstance(p, ColoTensor):
                yield name, p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield name, p

    def colo_parameters(self, *args, **kargs):
        for _, p in named_params_with_colotensor(self, *args, **kargs):
            yield p

    def colo_named_parameters(self, *args, **kargs):
        for name, p in named_params_with_colotensor(self, *args, **kargs):
            yield name, p

    module.old_named_parameters = module.named_parameters
    module.old_parameters = module.parameters

    funcType = types.MethodType
    module.parameters = funcType(fake_parameters, module)
    module.named_parameters = funcType(fake_named_parameters, module)
    module.colo_parameters = funcType(colo_parameters, module)
    module.colo_named_parameters = funcType(colo_named_parameters, module)
    module._colo_visited = True


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
    ...
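ColoModulize rebinds parameters() and named_parameters() on each module instance via types.MethodType, keeps the originals reachable as old_parameters/old_named_parameters, and marks the instance with _colo_visited so it is patched only once. A standalone sketch of that per-instance binding pattern (the class and function names below are invented for illustration):

    import types

    class Box:
        def __init__(self, value):
            self.value = value

    def doubled(self):
        return self.value * 2

    b = Box(21)
    # Bind the free function to this single instance, leaving the class untouched;
    # this mirrors how ColoModulize installs fake_parameters() per module.
    b.doubled = types.MethodType(doubled, b)
    print(b.doubled())              # 42
    print(hasattr(Box, 'doubled'))  # False: only the instance was patched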
@@ -24,8 +82,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
+        if hasattr(module, '_colo_visited'):
+            return
+
         name_list = []
-        for name, param in module.named_parameters():
+        for name, param in module.named_parameters(recurse=False):
             if isinstance(param, ColoTensor):
                 continue
             name_list.append((name, param))
@@ -35,3 +96,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(module, name)
             setattr(module, name,
                     ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
+
+        ColoModulize(module)
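Taken together, the post-init hook now converts each module's parameters to ColoTensors at construction time and then patches that module's parameter iterators. A hedged usage sketch (assuming ColoInitContext can be constructed with its defaults; its constructor signature is outside this diff):

    import torch
    import torch.nn as nn
    from colossalai.tensor import ColoTensor
    from colossalai.utils.model.colo_init_context import ColoInitContext

    with ColoInitContext():      # assumption: default arguments suffice
        model = nn.Linear(4, 4)

    # After the hook runs, parameters() yields the plain torch.Tensor payloads
    # (so existing torch code keeps working), while the new colo_parameters()
    # yields the ColoTensor wrappers themselves.
    for p in model.parameters():
        assert isinstance(p, torch.Tensor)
    for p in model.colo_parameters():
        assert isinstance(p, ColoTensor)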
tests/test_tensor/test_model.py

@@ -48,7 +48,7 @@ def run_1d_row_tp():
     model_torch = model_torch.cuda()

     # A naive way to set spec for all weights in Linear
-    for name, p in named_params_with_colotensor(model):
+    for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
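The test now iterates via the bound colo_named_parameters() that ColoModulize installs, instead of calling the free helper on the model. The substring filter above keeps only dense linear weights, skipping LayerNorm and embedding weights; a self-contained sketch of that filter (parameter names invented for illustration):

    names = ['blocks.0.attn.proj.weight', 'blocks.0.LayerNorm.weight',
             'ln_f.weight', 'embed.weight', 'blocks.0.attn.proj.bias']
    selected = [n for n in names
                if 'weight' in n and 'LayerNorm' not in n and 'ln' not in n and 'embed' not in n]
    print(selected)   # ['blocks.0.attn.proj.weight']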