OpenDAS / ColossalAI / Commits / d01d3b8c

Unverified commit d01d3b8c, authored Apr 25, 2022 by Jiarui Fang, committed by GitHub on Apr 25, 2022

colo init context add device attr. (#866)

Parent: 2238758c
Showing 3 changed files with 36 additions and 12 deletions.
colossalai/tensor/colo_tensor.py (+9, -5)
colossalai/utils/model/colo_init_context.py (+10, -2)
tests/test_tensor/test_context.py (+17, -5)
colossalai/tensor/colo_tensor.py

@@ -58,6 +58,10 @@ class ColoTensor(object):
     def shape(self):
         return torch.Size(self._size)
 
+    @property
+    def device(self):
+        return self._torch_tensor.device
+
     def size(self, dim=None):
         if dim is None:
             return self.shape

@@ -105,7 +109,7 @@ class ColoTensor(object):
             device=self._device)
         return self._torch_tensor
 
-    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
+    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
         self._shard_spec = spec
         if lazy_shard == False:
             self._shard()

@@ -121,8 +125,8 @@ class ColoTensor(object):
         # Reshape to get shard for this rank and we don't want autograd
         # recording here for the narrow op and 'local_shard' should be a
         # leaf variable in the autograd graph.
-        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()  # TODO Shall we clone() here since detach() will point to the old tensor?
+        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()  # TODO Shall we clone() here since detach() will point to the old tensor?
         self._torch_tensor.requires_grad = self._requires_grad
         self._size = self._torch_tensor.size()
         self._device = device  # TODO A `fake` device now because torch_tensor.device always = cpu
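The effect of the new `device` property is that a `ColoTensor` now reports the device of its underlying payload instead of exposing `_torch_tensor` directly. A minimal sketch of the property in use, assuming `save_payload=True` keeps the wrapped tensor as the payload (the `init_from_torch_tensor` call mirrors the one in `colo_init_context.py` below; the rest is illustrative):

```python
import torch
from colossalai.tensor import ColoTensor

# Wrap a plain torch.Tensor; the payload stays wherever that tensor lives.
t = torch.randn(4, 5)
colo_t = ColoTensor.init_from_torch_tensor(tensor=t, save_payload=True)

# The new property simply forwards to self._torch_tensor.device.
assert colo_t.device == t.device  # torch.device('cpu') in this sketch
```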
colossalai/utils/model/colo_init_context.py

 from colossalai.utils.cuda import get_current_device
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
 # from colossalai.logging import get_dist_logger

@@ -8,9 +9,15 @@ from colossalai.tensor import ColoTensor
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
 
-    def __init__(self, lazy_memory_allocate=False):
+    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
+        """
+        Args:
+            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors. Defaults to False.
+            device (torch.device, optional): the device parameters initialized are resident on. Defaults to torch.device('cpu').
+        """
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
+        self._device = device
 
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """

@@ -26,4 +33,5 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         save_torch_payload = True if not self._lazy_memory_allocate else False
         for name, param in name_list:
             delattr(module, name)
-            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=save_torch_payload))
+            setattr(module, name,
+                    ColoTensor.init_from_torch_tensor(tensor=param.to(self._device), save_payload=save_torch_payload))
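Taken together, the new `device` argument means every parameter created inside the context is re-created as a `ColoTensor` whose payload has first been moved via `param.to(self._device)`. A hedged usage sketch follows; the import path for `ColoInitContext` is an assumption based on the file location, and the actual public export may differ:

```python
import torch
from colossalai.utils.model.colo_init_context import ColoInitContext  # assumed import path

# Place freshly initialized parameters on the GPU when one is available.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

with ColoInitContext(lazy_memory_allocate=False, device=device):
    model = torch.nn.Linear(4, 5, bias=True)

# Each parameter is now a ColoTensor; its payload was moved with .to(device).
print(type(model.weight), model.weight.device)
```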
tests/test_tensor/test_context.py

@@ -5,17 +5,16 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils.cuda import get_current_device
 
-def test_linear():
+def test_lazy_init():
     in_dim = 4
     out_dim = 5
 
     with ColoInitContext(lazy_memory_allocate=True) as ctx:
         fc = torch.nn.Linear(in_dim, out_dim, bias=True)
 
     print(fc.weight.numel())
     print(fc.bias.numel())
 
     # lazy_memory_allocate=True, no payload is maintained
     assert fc.weight._torch_tensor.numel() == 0

@@ -23,5 +22,18 @@ def test_linear():
     assert fc.weight._torch_tensor.numel() == in_dim * out_dim
 
 
+def test_device():
+    in_dim = 4
+    out_dim = 5
+
+    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()) as ctx:
+        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+
+    # eval an lazy parameter
+    fc.weight.torch_tensor()
+    assert fc.weight.device == get_current_device()
+
+
 if __name__ == '__main__':
-    test_linear()
+    test_lazy_init()
+    test_device()
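For reference, the two tests compose into a single lazy-then-materialize pattern: with `lazy_memory_allocate=True` no payload exists until `torch_tensor()` is called, after which the payload should report the device handed to the context. A sketch combining them, with the `ColoInitContext` import path assumed as above:

```python
import torch
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext  # assumed import path

def check_lazy_then_materialize():
    with ColoInitContext(lazy_memory_allocate=True, device=get_current_device()):
        fc = torch.nn.Linear(4, 5, bias=True)

    # Lazy mode: no payload has been allocated yet.
    assert fc.weight._torch_tensor.numel() == 0

    # Calling torch_tensor() materializes the payload on the context device.
    fc.weight.torch_tensor()
    assert fc.weight.device == get_current_device()

if __name__ == '__main__':
    check_lazy_then_materialize()
```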