OpenDAS / ColossalAI · Commit ab8c6b4a

Unverified commit ab8c6b4a, authored Apr 11, 2022 by ver217, committed by GitHub on Apr 11, 2022.

[zero] refactor memstats collector (#706)

* refactor memstats collector
* fix disposable
* polish code

parent 3fc8a204

Showing 8 changed files with 44 additions and 114 deletions (+44, -114).
colossalai/utils/__init__.py                                  +2   -2
colossalai/utils/common.py                                    +17  -3
colossalai/utils/memory_tracer/memstats_collector.py          +17  -44
colossalai/utils/memory_tracer/test_async_memtracer.py        +0   -16
colossalai/utils/memory_tracer/test_memstats_collector.py     +0   -37
colossalai/zero/shard_utils/stateful_tensor_mgr.py            +1   -2
colossalai/zero/sharded_model/sharded_model_v2.py             +7   -9
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py     +0   -1
colossalai/utils/__init__.py

@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                      is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param)
+                     sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory_utils.memory_monitor import report_memory_usage

@@ -19,5 +19,5 @@ __all__ = [
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
     'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists'
+    'ensure_path_exists', 'disposable'
 ]
colossalai/utils/common.py

@@ -4,8 +4,8 @@ import os
 import random
 import socket
 from pathlib import Path
-from typing import List, Union
-
+from typing import Callable, List, Union
+import functools
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter

@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
 class model_branch_context(object):

     def __enter__(self):
         self.env_status = env.save()

@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
         colossal_C.multi_tensor_l2norm,
         dummy_overflow_buf,
         [grads],
         False    # no per-parameter norm
     )
     return norm

@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
         yield
     finally:
         gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper
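
For context, the new disposable helper turns a callable into a run-once wrapper: the first call executes the wrapped function and returns its result, and every later call is a silent no-op returning None. A minimal usage sketch (the warmup callable is hypothetical, for illustration only):

    from colossalai.utils import disposable

    def warmup():
        # hypothetical function standing in for any one-shot setup step
        print('collecting warmup memstats')

    warmup_once = disposable(warmup)
    warmup_once()    # prints 'collecting warmup memstats'
    warmup_once()    # no-op: the wrapped function has already run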
colossalai/utils/memory_tracer/memstats_collector.py

 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List
-
-
-class SamplingCounter:
-
-    def __init__(self) -> None:
-        self._samplint_cnt = 0
-        self._max_sampling_cnt = None
-
-    def advance(self):
-        self._samplint_cnt += 1
-
-    def next(self):
-        assert self._max_sampling_cnt is not None
-        return (self._samplint_cnt + 1) % self._max_sampling_cnt
-
-    def current(self):
-        return self._samplint_cnt
-
-    def max(self):
-        return self._max_sampling_cnt
-
-    def reset(self):
-        self._max_sampling_cnt = self._samplint_cnt
-        self._samplint_cnt = 0


 class MemStatsCollector:
     """
     A Memory statistic collector.

@@ -44,7 +19,6 @@ class MemStatsCollector:
     """

     def __init__(self) -> None:
-        self._sampling_cnter = SamplingCounter()
         self._mem_monitor = AsyncMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []

@@ -57,6 +31,7 @@ class MemStatsCollector:
         self._sampling_time = []

         self._start_flag = False
+        self._period_idx = 0

     def overall_mem_stats(self, device_type: str):
         if device_type == 'cuda':

@@ -106,15 +81,22 @@ class MemStatsCollector:
         else:
             raise TypeError

-    def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of the current sampling moment"""
-        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
-
-    def next_non_model_data(self, device_type: str):
-        """get the non model data of the next sampling moment"""
-        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+    def max_non_model_data(self, device_type: str) -> int:
+        """Get max non model data memory usage of current sampling period
+
+        Args:
+            device_type (str): device type, can be 'cpu' or 'cuda'.
+
+        Returns:
+            int: max non model data memory usage of current sampling period
+        """
+        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
+        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
+        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
+        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
+        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
+        self._period_idx = next_period_idx
+        return max(current_non_model_data, next_non_model_data)

     @property
     def sampling_time(self):

@@ -126,6 +108,7 @@ class MemStatsCollector:
     def finish_collection(self):
         self._start_flag = False
+        self._mem_monitor.finish()

     def sample_memstats(self) -> None:
         """

@@ -134,8 +117,6 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.current()
-            assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
             self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])

@@ -146,13 +127,6 @@ class MemStatsCollector:
             self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
-        # TODO(ver217): refactor sampler
-        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
-        self._sampling_cnter.advance()
-
-    def reset_sampling_cnter(self) -> None:
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()

     def clear(self) -> None:
         self._model_data_cuda_list = []

@@ -162,5 +136,4 @@ class MemStatsCollector:
         self._overall_cpu_list = []

         self._start_flag = False
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()
+        self._period_idx = 0
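
The refactor replaces SamplingCounter with a single rotating _period_idx: each call to max_non_model_data returns the larger of the current and the next period's non-model data, then advances the index, wrapping around at the number of recorded samples. A standalone sketch of that rotation, with invented numbers:

    # Hypothetical per-period non-model-data samples (bytes) recorded during warmup.
    non_model_data = [100, 300, 200]
    period_idx = 0

    def max_non_model_data() -> int:
        # Return max(current period, next period), then advance the rotating index.
        global period_idx
        next_idx = (period_idx + 1) % len(non_model_data)
        peak = max(non_model_data[period_idx], non_model_data[next_idx])
        period_idx = next_idx
        return peak

    print(max_non_model_data())    # max(100, 300) -> 300, index advances to 1
    print(max_non_model_data())    # max(300, 200) -> 300, index advances to 2
    print(max_non_model_data())    # max(200, 100) -> 200, index wraps back to 0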
colossalai/utils/memory_tracer/test_async_memtracer.py
deleted 100644 → 0

-from async_memtracer import AsyncMemoryMonitor
-import torch
-
-if __name__ == '__main__':
-    async_mem_monitor = AsyncMemoryMonitor()
-    input = torch.randn(2, 20).cuda()
-    OP1 = torch.nn.Linear(20, 30).cuda()
-    OP2 = torch.nn.Linear(30, 40).cuda()
-
-    async_mem_monitor.start()
-    output = OP1(input)
-    async_mem_monitor.finish()
-    async_mem_monitor.start()
-    output = OP2(output)
-    async_mem_monitor.finish()
-    async_mem_monitor.save('log.pkl')
colossalai/utils/memory_tracer/test_memstats_collector.py
deleted 100644 → 0

-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-import torch
-
-
-def test_mem_collector():
-    collector = MemStatsCollector()
-    collector.start_collection()
-    a = torch.randn(10).cuda()
-    # sampling at time 0
-    collector.sample_memstats()
-    m_a = torch.randn(10).cuda()
-    b = torch.randn(10).cuda()
-    # sampling at time 1
-    collector.sample_memstats()
-    a = b
-    # sampling at time 2
-    collector.sample_memstats()
-    collector.finish_collection()
-    collector.reset_sampling_cnter()
-    # do nothing after collection, just advance sampling cnter
-    collector.sample_memstats()
-    collector.sample_memstats()
-    print(collector.overall_mem_stats('cuda'))
-
-
-if __name__ == '__main__':
-    test_mem_collector()
colossalai/zero/shard_utils/stateful_tensor_mgr.py

@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
             max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
-                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
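
To see how the new accessor feeds the placement budget: the manager reserves the period's peak non-model data out of total CUDA capacity, and whatever model data allowance is not already used becomes the room for moving more tensors onto the device. A worked sketch with invented figures (GiB):

    cuda_capacity = 16.0                         # total CUDA memory
    max_cuda_non_model_data_per_period = 3.0     # peak from max_non_model_data('cuda')
    used_cuda_model_data = 9.0                   # model data already resident on CUDA

    total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period   # 13.0
    avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data         # 4.0
    print(avail_cuda_model_data)    # 4.0 GiB of headroom for additional model data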
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER

@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
             for param in submodule.parameters(recurse=False):
                 if hasattr(param, 'colo_attr'):
                     self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
+            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
             self._stateful_tensor_mgr = None
-        self._iter_cnter = 0

         # Register hooks
         self._ophook_list = [

@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
             f.write('\n')

     def _pre_forward_operations(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
+        if self._memstats_collector:
             # the operation will affect the memory tracer behavior in ZeroHook
-            self._memstats_collector.start_collection()
+            self._start_collect_memstats()
         for p in self.module.parameters():
             if hasattr(p, 'colo_attr'):

@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
             ophook.post_iter()

     def _update_memstats(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            self._memstats_collector.finish_collection()
         if self._memstats_collector:
-            self._memstats_collector.reset_sampling_cnter()
+            self._finish_collect_memstats()
             # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_cuda_memory_capacity() - max(
                 self._memstats_collector.overall_mem_stats('cuda'))
-        self._iter_cnter += 1

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
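
The _iter_cnter bookkeeping disappears because disposable now encodes the run-exactly-once intent at wrap time: the collector's start/finish methods are wrapped once in __init__ and can then be called unconditionally every iteration. A minimal sketch of the pattern with a stand-in collector (only the method names mirror MemStatsCollector):

    from colossalai.utils import disposable

    class Collector:
        # stand-in for MemStatsCollector, for illustration only
        def start_collection(self):
            print('start')

        def finish_collection(self):
            print('finish')

    collector = Collector()
    start_once = disposable(collector.start_collection)
    finish_once = disposable(collector.finish_collection)

    for _ in range(3):    # three training iterations
        start_once()      # 'start' is printed only on the first iteration
        finish_once()     # 'finish' is printed only on the first iteration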
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py

@@ -55,7 +55,6 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_memstats()
     mem_collector.finish_collection()
-    mem_collector.reset_sampling_cnter()
     stateful_tensor_mgr.reset()

     # warmup done