OpenDAS / ColossalAI · commit 59bf2dc5 (unverified)

[zero] initialize a stateful tensor manager (#614)

Authored Apr 06, 2022 by Jiarui Fang, committed via GitHub on Apr 06, 2022
Parent: cc236916
Showing 4 changed files with 93 additions and 3 deletions (+93, -3)
colossalai/utils/memory_tracer/__init__.py            +2   -1
colossalai/utils/memory_tracer/memstats_collector.py  +16  -2
colossalai/zero/shard_utils/stateful_tensor_mgr.py    +69  -0
colossalai/zero/sharded_param/sharded_param.py        +6   -0
colossalai/utils/memory_tracer/__init__.py

 from .async_memtracer import AsyncMemoryMonitor
+from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor']
+__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
colossalai/utils/memory_tracer/memstats_collector.py

@@ -11,15 +11,21 @@ class SamplingCounter:

     def __init__(self) -> None:
         self._samplint_cnt = 0
+        self._max_sampling_cnt = None

     def advance(self):
         self._samplint_cnt += 1

+    def next(self):
+        assert self._max_sampling_cnt is not None
+        return (self._samplint_cnt + 1) % self._max_sampling_cnt
+
     @property
     def sampling_cnt(self):
         return self._samplint_cnt

     def reset(self):
+        self._max_sampling_cnt = self._samplint_cnt
         self._samplint_cnt = 0

@@ -56,7 +62,7 @@ class MemStatsCollector:
         else:
             raise TypeError

-    def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
         if unit == 'GB':
             scale = 1e9
         elif unit == 'MB':

@@ -75,7 +81,7 @@ class MemStatsCollector:
         else:
             raise TypeError

-    def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
         """Non model data stats
         """
         if unit == 'GB':

@@ -96,6 +102,14 @@ class MemStatsCollector:
         else:
             raise TypeError

+    def current_non_model_data(self, device_type: str) -> int:
+        """get the non model data of current sampling moment
+        """
+        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+
+    def next_non_model_data(self, device_type: str):
+        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+
     @property
     def sampling_time(self):
         return [t - self._sampling_time[0] for t in self._sampling_time]
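The SamplingCounter changes let the collector replay memory statistics by sampling moment: reset() records the period observed during the warm-up iteration, and next() indexes the following moment modulo that period, which is what next_non_model_data relies on. A minimal sketch of the intended cycle, assuming the counter is advanced once per sampling moment; the loop below is illustrative and not part of the commit:

    # Illustrative sketch (not part of the commit): the advance/reset/next cycle.
    counter = SamplingCounter()

    # Warm-up iteration: advance once per sampling moment.
    for _ in range(4):
        counter.advance()
    counter.reset()                 # records the period (4) and restarts at 0

    # Subsequent iterations repeat the same moments, so the collector can
    # index stats for this moment and peek at the next one.
    counter.advance()               # sampling moment 1
    assert counter.sampling_cnt == 1
    assert counter.next() == 2      # wraps modulo the recorded period

current_non_model_data and next_non_model_data then index non_model_data_list(device_type) with sampling_cnt and next() respectively, which is how adjust_layout in the new manager sizes its CUDA budget.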
colossalai/zero/shard_utils/stateful_tensor_mgr.py (new file, mode 100644)

import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from typing import Set
from colossalai.utils.memory_tracer import MemStatsCollector


class StatefulTensorMgr(SingletonMeta):
    _stateful_tensor_list: Set[ShardedParamV2] = set()

    def register_param(self, param: ShardedParamV2) -> None:
        for t in param.get_payload_tensors():
            assert isinstance(t, StatefulTensor)
            self._stateful_tensor_list.add(t)

    def evict_tensors(self) -> None:
        pass

    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a Sharded Model.

        Args:
            mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
                It contains the non-model footprint of a DNN model.
        """
        # find stateful tensors in state COMPUTE
        move_to_cuda_tensor_list = []
        cuda_demand = 0
        used_cuda_model_data = 0
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue
            if tensor.device.type == 'cuda':
                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            else:
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]

        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
                                                 mem_stats_collector.next_non_model_data('cuda'))
        cuda_capacity = colo_cuda_memory_capacity()
        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
            # need to evict at least (used_cuda_model_data + cuda_demand - cuda_model_data_period) bytes.
            # Here use a naive eviction strategy.
            acc_size = 0
            for t in hold_cuda_tensor_list:
                if acc_size > cuda_demand:
                    break
                # measure the CUDA bytes freed before moving the tensor to CPU
                t_size = colo_tensor_mem_usage(t)[0]
                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
                acc_size += t_size
            if acc_size < cuda_demand:
                raise RuntimeError("Adjust layout failed! Not enough CUDA memory!")

        # move COMPUTE tensors to CUDA
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())
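A rough usage sketch, assuming the manager is obtained as a singleton instance (SingletonMeta is presumably intended as its metaclass) and that a warmed-up MemStatsCollector is owned by the sharded model; the variable names below are illustrative, not APIs added by this commit:

    # Illustrative sketch (assumptions, not part of the commit).
    mgr = StatefulTensorMgr()                # assumes singleton-style instantiation

    for sharded_param in sharded_params:     # hypothetical list of ShardedParamV2
        mgr.register_param(sharded_param)    # tracks the data and grad StatefulTensors

    # At each sampling moment, rebalance CUDA memory: evict HOLD* tensors under
    # the naive strategy if the budget is exceeded, then move COMPUTE-state
    # tensors onto the current device.
    mgr.adjust_layout(mem_stats_collector)   # collector owned by the Sharded Model

The CUDA budget here is the device capacity minus the larger of the current and next sampling moment's non-model footprint, so headroom is reserved for the activations the upcoming moment will need.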
colossalai/zero/sharded_param/sharded_param.py

@@ -3,6 +3,7 @@ from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
 from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
+from typing import List


 class ShardedParamV2(object):

@@ -22,6 +23,11 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()

+    def get_payload_tensors(self) -> List[StatefulTensor]:
+        """returns stateful tensors kept by this class.
+        """
+        return [self._sharded_data_tensor, self.saved_grad]
+
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
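get_payload_tensors is the hook that StatefulTensorMgr.register_param iterates over. A small sketch, assuming a ShardedParamV2 instance p and that colo_tensor_mem_usage returns a (cuda_use, cpu_use) pair, as the [0] indexing in the manager suggests:

    # Illustrative sketch (not part of the commit): total bytes held by the
    # stateful tensors of one sharded parameter.
    cuda_bytes, cpu_bytes = 0, 0
    for t in p.get_payload_tensors():        # [sharded data tensor, saved grad]
        cuda_use, cpu_use = colo_tensor_mem_usage(t.payload)
        cuda_bytes += cuda_use
        cpu_bytes += cpu_use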