Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
2938edf4
Unverified
Commit
2938edf4
authored
Dec 13, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 13, 2022
Browse files
[Gemini] update the non model data record method in runtime memory tracer (#2128)
parent
deee317b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
65 additions
and
56 deletions
+65
-56
colossalai/gemini/gemini_mgr.py
colossalai/gemini/gemini_mgr.py
+2
-2
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+1
-1
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+40
-37
colossalai/gemini/memory_tracer/memstats_collector.py
colossalai/gemini/memory_tracer/memstats_collector.py
+1
-1
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+3
-1
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+9
-11
colossalai/zero/utils/gemini_hook.py
colossalai/zero/utils/gemini_hook.py
+1
-1
colossalai/zero/utils/zero_hook.py
colossalai/zero/utils/zero_hook.py
+1
-1
tests/test_gemini/update/test_gemini_use_rmt.py
tests/test_gemini/update/test_gemini_use_rmt.py
+7
-1
No files found.
colossalai/gemini/gemini_mgr.py
View file @
2938edf4
...
@@ -133,9 +133,9 @@ class GeminiManager:
...
@@ -133,9 +133,9 @@ class GeminiManager:
if
self
.
_mem_stats_collector
:
if
self
.
_mem_stats_collector
:
self
.
_mem_stats_collector
.
sample_overall_data
()
self
.
_mem_stats_collector
.
sample_overall_data
()
def
sample
_model_data
(
self
):
def
record
_model_data
_volume
(
self
):
if
self
.
_mem_stats_collector
:
if
self
.
_mem_stats_collector
:
self
.
_mem_stats_collector
.
sample
_model_data
()
self
.
_mem_stats_collector
.
record
_model_data
_volume
()
@
property
@
property
def
chunk_manager
(
self
):
def
chunk_manager
(
self
):
...
...
colossalai/gemini/memory_tracer/chunk_memstats_collector.py
View file @
2938edf4
...
@@ -15,7 +15,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
...
@@ -15,7 +15,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
self
.
_chunk_manager
=
chunk_manager
self
.
_chunk_manager
=
chunk_manager
# override
# override
def
sample
_model_data
(
self
)
->
None
:
def
record
_model_data
_volume
(
self
)
->
None
:
"""Sampling model data statistics.
"""Sampling model data statistics.
"""
"""
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
...
...
colossalai/gemini/memory_tracer/memory_stats.py
View file @
2938edf4
...
@@ -15,7 +15,7 @@ class MemStats(object):
...
@@ -15,7 +15,7 @@ class MemStats(object):
self
.
_step_param_dict
=
dict
()
self
.
_step_param_dict
=
dict
()
# (param, List[preop_step])
# (param, List[preop_step])
self
.
_param_step_dict
=
dict
()
self
.
_param_step_dict
=
dict
()
# (preop_step, non_model_data)
# (preop_step, non_model_data)
non model data used during preop_step ~ (preop_step+1)
self
.
_step_nmd_dict
=
dict
()
self
.
_step_nmd_dict
=
dict
()
self
.
_param_runtime_order
=
OrderedParamGenerator
()
self
.
_param_runtime_order
=
OrderedParamGenerator
()
...
@@ -23,9 +23,8 @@ class MemStats(object):
...
@@ -23,9 +23,8 @@ class MemStats(object):
self
.
_prev_overall_cuda
=
-
1
self
.
_prev_overall_cuda
=
-
1
self
.
_prev_md_cuda
=
-
1
self
.
_prev_md_cuda
=
-
1
# old version
self
.
param_non_model_data_map
:
Dict
(
Any
,
List
[
int
])
=
{}
# old version
self
.
_model_data_cuda_list
=
[]
self
.
_model_data_cuda_list
=
[]
self
.
_model_data_cpu_list
=
[]
self
.
_model_data_cpu_list
=
[]
...
@@ -35,9 +34,12 @@ class MemStats(object):
...
@@ -35,9 +34,12 @@ class MemStats(object):
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
def
record
_max_cuda_non_model_data
(
self
):
def
calc
_max_cuda_non_model_data
(
self
):
if
self
.
_prev_overall_cuda
!=
-
1
and
self
.
_prev_md_cuda
!=
-
1
:
if
self
.
_prev_overall_cuda
!=
-
1
and
self
.
_prev_md_cuda
!=
-
1
:
self
.
_step_nmd_dict
[
self
.
_preop_step
]
=
self
.
_prev_overall_cuda
-
self
.
_prev_md_cuda
max_cuda_non_model_data
=
self
.
_prev_overall_cuda
-
self
.
_prev_md_cuda
self
.
_step_nmd_dict
[
self
.
_preop_step
-
1
]
=
max_cuda_non_model_data
# compatibility of the old version.
self
.
_non_model_data_cuda_list
.
append
(
max_cuda_non_model_data
)
def
record_max_cuda_model_data
(
self
,
val
):
def
record_max_cuda_model_data
(
self
,
val
):
self
.
_prev_md_cuda
=
val
self
.
_prev_md_cuda
=
val
...
@@ -45,12 +47,45 @@ class MemStats(object):
...
@@ -45,12 +47,45 @@ class MemStats(object):
def
record_max_cuda_overall_data
(
self
,
val
):
def
record_max_cuda_overall_data
(
self
,
val
):
self
.
_prev_overall_cuda
=
val
self
.
_prev_overall_cuda
=
val
def
increase_preop_step
(
self
,
param_list
:
List
[
torch
.
nn
.
Parameter
]):
"""
the time step is increased. param list is used between current and the next
time step.
Args:
param_list (List[torch.nn.Parameter]): a list of torch parameters.
"""
for
p
in
param_list
:
if
p
not
in
self
.
_param_step_dict
:
self
.
_param_step_dict
[
p
]
=
[
self
.
_preop_step
]
else
:
self
.
_param_step_dict
[
p
].
append
(
self
.
_preop_step
)
self
.
_param_runtime_order
.
append
(
p
)
self
.
_step_param_dict
[
self
.
_preop_step
]
=
param_list
self
.
_preop_step
+=
1
def
param_used_step
(
self
,
param
:
torch
.
nn
.
Parameter
)
->
Optional
[
List
[
int
]]:
"""param_used_step
get the timestep list using the param
Args:
param (torch.nn.Parameter): a torch param
Returns:
Optional[List[int]]: a list of ints indicating the time steps of the preop hook.
"""
if
param
not
in
self
.
_param_step_dict
:
return
None
else
:
return
self
.
_param_step_dict
[
param
]
def
param_order
(
self
):
def
param_order
(
self
):
if
self
.
_param_runtime_order
.
is_empty
():
if
self
.
_param_runtime_order
.
is_empty
():
raise
RuntimeError
raise
RuntimeError
else
:
else
:
return
self
.
_param_runtime_order
return
self
.
_param_runtime_order
## APIs to be deprecated
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
if
device_type
==
'cuda'
:
self
.
_overall_cuda_list
.
append
(
val
)
self
.
_overall_cuda_list
.
append
(
val
)
...
@@ -135,38 +170,6 @@ class MemStats(object):
...
@@ -135,38 +170,6 @@ class MemStats(object):
else
:
else
:
raise
TypeError
raise
TypeError
def
increase_preop_step
(
self
,
param_list
:
List
[
torch
.
nn
.
Parameter
]):
"""
the time step is increased. param list is used between current and the next
time step.
Args:
param_list (List[torch.nn.Parameter]): a list of torch parameters.
"""
for
p
in
param_list
:
if
p
not
in
self
.
_param_step_dict
:
self
.
_param_step_dict
[
p
]
=
[
self
.
_preop_step
]
else
:
self
.
_param_step_dict
[
p
].
append
(
self
.
_preop_step
)
self
.
_param_runtime_order
.
append
(
p
)
self
.
_step_param_dict
[
self
.
_preop_step
]
=
param_list
self
.
_preop_step
+=
1
def
param_used_timestep
(
self
,
param
:
torch
.
nn
.
Parameter
)
->
Optional
[
List
[
int
]]:
"""param_used_timestep
get the timestep list using the param
Args:
param (torch.nn.Parameter): a torch param
Returns:
Optional[List[int]]: a list of ints indicating the time steps of the preop hook.
"""
if
param
not
in
self
.
_param_step_dict
:
return
None
else
:
return
self
.
_param_step_dict
[
param
]
def
clear
(
self
):
def
clear
(
self
):
self
.
_model_data_cuda_list
=
[]
self
.
_model_data_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
self
.
_overall_cuda_list
=
[]
...
...
colossalai/gemini/memory_tracer/memstats_collector.py
View file @
2938edf4
...
@@ -69,7 +69,7 @@ class MemStatsCollector:
...
@@ -69,7 +69,7 @@ class MemStatsCollector:
self
.
_start_flag
=
False
self
.
_start_flag
=
False
self
.
_mem_monitor
.
finish
()
self
.
_mem_monitor
.
finish
()
def
sample
_model_data
(
self
)
->
None
:
def
record
_model_data
_volume
(
self
)
->
None
:
"""Sampling model data statistics.
"""Sampling model data statistics.
"""
"""
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
if
self
.
_start_flag
and
not
self
.
use_outside_memstats
:
...
...
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
View file @
2938edf4
...
@@ -82,7 +82,9 @@ class RuntimeMemTracer():
...
@@ -82,7 +82,9 @@ class RuntimeMemTracer():
def
_post_backward
(
self
):
def
_post_backward
(
self
):
cuda_volume
=
self
.
param_op_hook
.
mem_monitor
.
finish
()
cuda_volume
=
self
.
param_op_hook
.
mem_monitor
.
finish
()
self
.
_memstats
.
append_non_model_data
(
'cuda'
,
cuda_volume
-
self
.
_memstats
.
last_model_data
(
'cuda'
))
self
.
_memstats
.
record_max_cuda_overall_data
(
cuda_volume
)
# calc the last Op non model data
self
.
_memstats
.
calc_max_cuda_non_model_data
()
self
.
grad_hook
.
remove_grad_hook
()
self
.
grad_hook
.
remove_grad_hook
()
self
.
_restore_params
()
self
.
_restore_params
()
...
...
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
View file @
2938edf4
...
@@ -86,7 +86,7 @@ class ParamMemTracerHook(ColoParamOpHook):
...
@@ -86,7 +86,7 @@ class ParamMemTracerHook(ColoParamOpHook):
elif
cur_dev
==
"cuda"
:
elif
cur_dev
==
"cuda"
:
alloc_storage
(
p
.
data
)
alloc_storage
(
p
.
data
)
def
sample
_model_data
(
self
,
params
):
def
record
_model_data
_volume
(
self
,
params
):
"""
"""
get cuda model data used by params
get cuda model data used by params
"""
"""
...
@@ -100,21 +100,19 @@ class ParamMemTracerHook(ColoParamOpHook):
...
@@ -100,21 +100,19 @@ class ParamMemTracerHook(ColoParamOpHook):
if
not
self
.
_grad_stats
.
unreleased_grad_flag
[
p
]:
if
not
self
.
_grad_stats
.
unreleased_grad_flag
[
p
]:
self
.
_grad_stats
.
unreleased_grad_volume
+=
cur_model_data_volume
self
.
_grad_stats
.
unreleased_grad_volume
+=
cur_model_data_volume
self
.
_grad_stats
.
unreleased_grad_flag
[
p
]
=
True
self
.
_grad_stats
.
unreleased_grad_flag
[
p
]
=
True
self
.
_memstats
.
append_model_data
(
'cuda'
,
data_volume
)
# record max non model data used for this Op
# record max non model data used for this Op
self
.
_memstats
.
record_max_cuda_model_data
(
data_volume
)
self
.
_memstats
.
record_max_cuda_model_data
(
data_volume
)
def
pre_op
(
self
,
params
):
def
pre_op
(
self
,
params
):
# get overall cuda data.
max_cuda_used_pre_op
=
self
.
mem_monitor
.
finish
()
max_cuda_vol_of_period
=
self
.
mem_monitor
.
finish
()
# record max cuda overall data for prev OP.
# record max cuda overall data for prev Op.
self
.
_memstats
.
record_max_cuda_overall_data
(
max_cuda_used_pre_op
)
self
.
_memstats
.
record_max_cuda_overall_data
(
max_cuda_vol_of_period
)
# record max cuda non model data for prev OP.
self
.
_memstats
.
record_max_cuda_non_model_data
()
self
.
_memstats
.
calc_max_cuda_non_model_data
()
max_cuda_model_data_val
=
self
.
_memstats
.
last_model_data
(
'cuda'
)
if
max_cuda_model_data_val
is
not
None
:
self
.
_memstats
.
append_non_model_data
(
'cuda'
,
max_cuda_vol_of_period
-
max_cuda_model_data_val
)
self
.
_allocate_params_on_cuda
(
params
)
self
.
_allocate_params_on_cuda
(
params
)
self
.
sample_model_data
(
params
)
# record max cuda model data for current OP
self
.
record_model_data_volume
(
params
)
self
.
mem_monitor
.
start
()
self
.
mem_monitor
.
start
()
self
.
_memstats
.
increase_preop_step
(
params
)
self
.
_memstats
.
increase_preop_step
(
params
)
...
...
colossalai/zero/utils/gemini_hook.py
View file @
2938edf4
...
@@ -32,7 +32,7 @@ class GeminiZeROHook(ColoParamOpHook):
...
@@ -32,7 +32,7 @@ class GeminiZeROHook(ColoParamOpHook):
self
.
_gemini_manager
.
adjust_layout
(
chunks
)
self
.
_gemini_manager
.
adjust_layout
(
chunks
)
for
chunk
in
chunks
:
for
chunk
in
chunks
:
self
.
_chunk_manager
.
access_chunk
(
chunk
)
self
.
_chunk_manager
.
access_chunk
(
chunk
)
self
.
_gemini_manager
.
sample
_model_data
()
self
.
_gemini_manager
.
record
_model_data
_volume
()
def
post_op
(
self
,
params
):
def
post_op
(
self
,
params
):
params
=
[
p
for
p
in
params
if
not
getattr
(
p
,
'_ddp_to_ignore'
,
False
)]
params
=
[
p
for
p
in
params
if
not
getattr
(
p
,
'_ddp_to_ignore'
,
False
)]
...
...
colossalai/zero/utils/zero_hook.py
View file @
2938edf4
...
@@ -67,7 +67,7 @@ class ZeroHook(BaseOpHook):
...
@@ -67,7 +67,7 @@ class ZeroHook(BaseOpHook):
# record model data statistics
# record model data statistics
if
self
.
_memstarts_collector
:
if
self
.
_memstarts_collector
:
self
.
_memstarts_collector
.
sample
_model_data
()
self
.
_memstarts_collector
.
record
_model_data
_volume
()
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
self
.
adjust_module_data
(
module
)
self
.
adjust_module_data
(
module
)
...
...
tests/test_gemini/update/test_gemini_use_rmt.py
View file @
2938edf4
...
@@ -47,7 +47,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
...
@@ -47,7 +47,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
runtime_tracer_non_model_data
=
runtime_mem_tracer
.
_memstats
.
_non_model_data_cuda_list
runtime_tracer_non_model_data
=
runtime_mem_tracer
.
_memstats
.
_non_model_data_cuda_list
print
(
'runtime tracer non model data points: '
,
len
(
runtime_tracer_non_model_data
))
print
(
'runtime tracer non model data points: '
,
len
(
runtime_tracer_non_model_data
))
print
(
'runtime tracer: '
,
runtime_tracer_non_model_data
)
print
(
'runtime tracer: '
,
runtime_tracer_non_model_data
)
print
([
memstats
.
param_used_timestep
(
p
)
for
p
in
model
.
parameters
()])
print
([
memstats
.
param_used_step
(
p
)
for
p
in
model
.
parameters
()])
if
model_name
==
'repeated_computed_layers'
:
for
idx
,
p
in
enumerate
(
model
.
parameters
()):
step_list
=
memstats
.
param_used_step
(
p
)
if
idx
<
4
:
assert
len
(
step_list
)
==
4
if
model_name
==
'repeated_computed_layers'
:
if
model_name
==
'repeated_computed_layers'
:
for
idx
,
p
in
enumerate
(
model
.
parameters
()):
for
idx
,
p
in
enumerate
(
model
.
parameters
()):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment