OpenDAS / ColossalAI · Commits

Unverified commit 26c49639, authored Apr 27, 2022 by Jiarui Fang and committed via GitHub on Apr 27, 2022:
[Tensor] overriding paramters() for Module using ColoTensor (#889)
Parent: daf59ff7

Showing 3 changed files, with 74 additions and 6 deletions:
  colossalai/tensor/colo_tensor.py              +6  -1
  colossalai/utils/model/colo_init_context.py   +67 -4
  tests/test_tensor/test_model.py               +1  -1

colossalai/tensor/colo_tensor.py

@@ -165,7 +165,12 @@ class ColoTensor(object):
         self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
 
     def __add__(self, o) -> "ColoTensor":
-        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        if isinstance(o, ColoTensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())
+        elif isinstance(o, torch.Tensor):
+            return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o)
+        else:
+            raise TypeError(f'{type(o)} is not supported in ColoTensor __add__')
 
     def __truediv__(self, o) -> "ColoTensor":
         return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)
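
__add__ previously assumed the right-hand operand was another ColoTensor; it now dispatches on the operand type. A minimal sketch of the new behavior, assuming the ColossalAI build from this commit is importable (the toy tensors are illustrative only):

    import torch
    from colossalai.tensor import ColoTensor

    a = ColoTensor.init_from_torch_tensor(torch.ones(2, 2))
    b = ColoTensor.init_from_torch_tensor(torch.full((2, 2), 2.0))

    c = a + b                  # ColoTensor + ColoTensor: adds the wrapped payloads
    d = a + torch.ones(2, 2)   # ColoTensor + torch.Tensor: newly supported
    e = a / 2.0                # __truediv__ is unchanged by this commit

    try:
        a + 1.0                # any other operand type now raises explicitly
    except TypeError as err:
        print(err)             # "<class 'float'> is not supported in ColoTensor __add__"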

colossalai/utils/model/colo_init_context.py

-from colossalai.utils.cuda import get_current_device
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-# from colossalai.logging import get_dist_logger
 from colossalai.tensor import ColoTensor
-# _orig_torch_empty = torch.empty
+import types
+
+from torch import nn
+from typing import Iterator, Tuple, Union
+
+
+def ColoModulize(module):
+    """
+    Replacing the parameters() and named_parameters() with our customized ones
+    """
+
+    def named_params_with_colotensor(
+            module: nn.Module,
+            prefix: str = '',
+            recurse: bool = True,
+    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
+        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
+
+        memo = set()
+        for mod_prefix, mod in modules:
+            # find all colotensors tensor params
+            for name, val in vars(mod).items():
+                if isinstance(val, ColoTensor) and val not in memo:
+                    memo.add(val)
+                    name = mod_prefix + ('.' if mod_prefix else '') + name
+                    yield name, val
+
+        # find all nn.Parameters
+        for name, val in module.old_named_parameters(recurse=recurse):
+            yield name, val
+
+    def fake_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            if isinstance(p, ColoTensor):
+                yield p.torch_tensor()
+            elif isinstance(p, torch.Tensor):
+                yield p
+
+    def fake_named_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            if isinstance(p, ColoTensor):
+                yield name, p.torch_tensor()
+            elif isinstance(p, torch.Tensor):
+                yield name, p
+
+    def colo_parameters(self, *args, **kargs):
+        for _, p in named_params_with_colotensor(self, *args, **kargs):
+            yield p
+
+    def colo_named_parameters(self, *args, **kargs):
+        for name, p in named_params_with_colotensor(self, *args, **kargs):
+            yield name, p
+
+    module.old_named_parameters = module.named_parameters
+    module.old_parameters = module.parameters
+
+    funcType = types.MethodType
+    module.parameters = funcType(fake_parameters, module)
+    module.named_parameters = funcType(fake_named_parameters, module)
+    module.colo_parameters = funcType(colo_parameters, module)
+    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module._colo_visited = True
+
+
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

@@ -24,8 +82,11 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
         """
+        if hasattr(module, '_colo_visited'):
+            return
+
         name_list = []
-        for name, param in module.named_parameters():
+        for name, param in module.named_parameters(recurse=False):
             if isinstance(param, ColoTensor):
                 continue
             name_list.append((name, param))

@@ -35,3 +96,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             delattr(module, name)
             setattr(module, name,
                     ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
+
+        ColoModulize(module)
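
Taken together: every module constructed under ColoInitContext has its parameters replaced by ColoTensors, then ColoModulize rebinds its methods so that parameters()/named_parameters() yield the raw torch payloads (existing optimizers and utilities keep working), while the new colo_parameters()/colo_named_parameters() expose the ColoTensor wrappers. A hedged usage sketch; the device keyword is an assumption, since ColoInitContext's constructor is outside this diff, and the nn.Linear model is illustrative only:

    import torch.nn as nn
    from colossalai.tensor import ColoTensor
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # device= is assumed here; the constructor signature is not part of this diff
    with ColoInitContext(device='cpu'):
        model = nn.Linear(4, 4)

    # patched named_parameters(): yields plain torch.Tensor payloads
    for name, p in model.named_parameters():
        assert not isinstance(p, ColoTensor)

    # new colo_named_parameters(): yields the ColoTensor wrappers themselves
    for name, p in model.colo_named_parameters():
        assert isinstance(p, ColoTensor)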

tests/test_tensor/test_model.py

@@ -48,7 +48,7 @@ def run_1d_row_tp():
     model_torch = model_torch.cuda()
 
     # A naive way to set spec for all weights in Linear
-    for name, p in named_params_with_colotensor(model):
+    for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
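
The test switches from the free helper named_params_with_colotensor (now nested inside ColoModulize, so no longer importable from colo_init_context) to the method the patch installs. Note that model.named_parameters() would not work here: after patching it yields raw torch tensors, so the isinstance ColoTensor filter would never match. A sketch of the idiom, with the spec-setting body elided since it is not part of this hunk:

    # assumes `model` was built under ColoInitContext so colo_named_parameters exists
    for name, p in model.colo_named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        # keep only Linear weights: skip LayerNorm ('LayerNorm'/'ln') and embeddings
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            ...  # set the 1D row tensor-parallel spec on p (elided in this hunk)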