OpenDAS / ColossalAI
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "f83ea813f5bb6ecc5597c7f1bf97870d46de1c49"
Commit 8789850e (unverified)
Init Conext supports lazy allocate model memory (#842)

Authored Apr 22, 2022 by Jiarui Fang; committed via GitHub on Apr 22, 2022
Parent commit: 4575a329
Showing 5 changed files, with 112 additions and 13 deletions (+112 −13):

    colossalai/tensor/colo_tensor.py             +4   −3
    colossalai/utils/__init__.py                 +41  −10
    colossalai/utils/model/colo_init_context.py  +40  −0
    colossalai/utils/model/utils.py              +0   −0
    tests/test_tensor/test_context.py            +27  −0
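In short, a model built under ColoInitContext(lazy_memory_allocate=True) gets ColoTensor parameters whose payload is not allocated at construction time. A minimal usage sketch, mirroring the new test at the bottom of this commit:

import torch
from colossalai.utils import ColoInitContext

# With lazy_memory_allocate=True the parameters become ColoTensors whose
# backing storage stays empty during construction.
with ColoInitContext(lazy_memory_allocate=True):
    fc = torch.nn.Linear(4, 5)

# The payload is materialized only when it is first requested.
weight = fc.weight.torch_tensor()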
colossalai/tensor/colo_tensor.py

@@ -43,9 +43,10 @@ class ColoTensor(object):
                             torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t
 
-    def del_torch_tensor(self) -> None:
-        self._size = (0,)
-        self._torch_tensor = torch.empty(self._size)
+    def del_torch_tensor(self, save_shape=False) -> None:
+        if not save_shape:
+            self._size = (0,)
+        self._torch_tensor = torch.empty((0,))
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
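The new save_shape flag is what makes the payload recoverable: the unchanged context lines show that torch_tensor() re-allocates the payload whenever the stored tensor is empty, presumably from self._size. A small sketch of that interplay, using only the methods visible in this commit (the exact re-materialization arguments live outside this hunk and are assumed):

import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(tensor=torch.randn(4, 5), save_payload=True)

t.del_torch_tensor(save_shape=True)    # drop the payload but keep the (4, 5) shape
assert t._torch_tensor.numel() == 0

restored = t.torch_tensor()            # re-materialized from the saved shape
assert restored.numel() == 4 * 5

t.del_torch_tensor()                   # default save_shape=False: the shape is reset too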
colossalai/utils/__init__.py

@@ -11,16 +11,47 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
                      colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
-from .model.init_context import InsertPostInitMethodToModuleSubClasses
+from .model.utils import InsertPostInitMethodToModuleSubClasses
+from .model.colo_init_context import ColoInitContext
 
 __all__ = [
-    'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
-    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
-    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
-    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
-    'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
-    'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
-    'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
-    'InsertPostInitMethodToModuleSubClasses'
+    'checkpoint',
+    'free_port',
+    'print_rank_0',
+    'sync_model_param',
+    'is_dp_rank_0',
+    'is_tp_rank_0',
+    'is_no_pp_or_last_stage',
+    'is_using_ddp',
+    'is_using_pp',
+    'is_using_sequence',
+    'conditional_context',
+    'is_model_parallel_parameter',
+    'clip_grad_norm_fp32',
+    'count_zeros_fp32',
+    'copy_tensor_parallel_attributes',
+    'param_is_not_tensor_parallel_duplicate',
+    'get_current_device',
+    'synchronize',
+    'empty_cache',
+    'set_to_cuda',
+    'report_memory_usage',
+    'colo_device_memory_capacity',
+    'colo_device_memory_used',
+    'colo_set_process_memory_fraction',
+    'Timer',
+    'MultiTimer',
+    'multi_tensor_applier',
+    'DataParallelSampler',
+    'get_dataloader',
+    'switch_virtual_pipeline_parallel_rank',
+    'TensorDetector',
+    'load_checkpoint',
+    'save_checkpoint',
+    'ensure_path_exists',
+    'disposable',
+    'colo_set_cpu_memory_capacity',
+    'colo_get_cpu_memory_capacity',
+    'InsertPostInitMethodToModuleSubClasses',
+    'ColoInitContext',
 ]
colossalai/utils/model/colo_init_context.py (new file, mode 100644)

from .utils import InsertPostInitMethodToModuleSubClasses
import torch
# from colossalai.logging import get_dist_logger
from colossalai.tensor import ColoTensor

# _orig_torch_empty = torch.empty


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self, lazy_memory_allocate=False):
        super().__init__()
        self._lazy_memory_allocate = lazy_memory_allocate

    def _pre_context_exec(self):
        """
        The Callback function when entering the context
        """
        pass

    def _post_context_exec(self):
        """The callback function when exiting context.
        """
        pass

    def _post_init_method(self, module: torch.nn.Module):
        """
        The function to call at the end of the constructor of each module.
        FIXME(fjr) The module may be passed to this function multiple times?
        """
        name_list = []
        for name, param in module.named_parameters():
            if isinstance(param, ColoTensor):
                continue
            name_list.append((name, param))

        save_torch_payload = True if not self._lazy_memory_allocate else False
        for name, param in name_list:
            delattr(module, name)
            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
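ColoInitContext only defines the hooks; the patching machinery it inherits lives in colossalai/utils/model/utils.py (the moved file below), whose body is not shown in this diff. For orientation, here is a rough, hypothetical sketch of the pattern a class named InsertPostInitMethodToModuleSubClasses typically implements: on entering the context, wrap __init__ of nn.Module subclasses so that _post_init_method() runs right after each module is constructed, and restore the originals on exit. Because a parent module sees its children's parameters again through named_parameters(), _post_init_method above skips parameters that are already ColoTensors, which is likely what the FIXME refers to.

import torch


class PostInitHookSketch:
    """Illustrative only; not the actual ColossalAI implementation."""

    def _post_init_method(self, module: torch.nn.Module):
        # Subclasses (like ColoInitContext) decide what to do with each module.
        print(f'constructed {module.__class__.__name__}')

    def _wrap_init(self, orig_init):
        def wrapper(module, *args, **kwargs):
            orig_init(module, *args, **kwargs)
            self._post_init_method(module)
        return wrapper

    def __enter__(self):
        # Patch direct nn.Module subclasses only, for brevity.
        self._originals = {cls: cls.__init__ for cls in torch.nn.Module.__subclasses__()}
        for cls, orig in self._originals.items():
            cls.__init__ = self._wrap_init(orig)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for cls, orig in self._originals.items():
            cls.__init__ = orig
        return False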
colossalai/utils/model/init_context.py → colossalai/utils/model/utils.py

File moved without content changes.
tests/test_tensor/test_context.py (new file, mode 100644)

from colossalai.utils import ColoInitContext
from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy


def test_linear():
    in_dim = 4
    out_dim = 5

    with ColoInitContext(lazy_memory_allocate=True) as ctx:
        fc = torch.nn.Linear(in_dim, out_dim, bias=True)

    print(fc.weight.numel())
    print(fc.bias.numel())

    # lazy_memory_allocate=True, no payload is maintained
    assert fc.weight._torch_tensor.numel() == 0

    fc.weight.torch_tensor()
    assert fc.weight._torch_tensor.numel() == in_dim * out_dim


if __name__ == '__main__':
    test_linear()