OpenDAS / ColossalAI

Commit 05bb28aa (unverified)
Authored by Jiarui Fang on Dec 13, 2022; committed by GitHub on Dec 13, 2022

[Gemini] mapping of preop timestep and param (#2124)

parent 764bc16f
Changes (3 changed files with 49 additions and 6 deletions):

    colossalai/gemini/memory_tracer/memory_stats.py       +46  -1
    colossalai/gemini/ophooks/runtime_mem_tracer_hook.py    +1  -4
    tests/test_gemini/update/test_gemini_use_rmt.py         +2  -1
colossalai/gemini/memory_tracer/memory_stats.py
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import torch
 
 from colossalai.gemini.memory_tracer import OrderedParamGenerator
@@ -10,6 +12,12 @@ class MemStats(object):
         Store the non model data statistics used for Gemini and ZeroOptimizer.
         """
+        # (preop_moment, List[param])
+        self._step_param_dict = dict()
+        # (param, List[preop_moment])
+        self._param_step_dict = dict()
+        # p -> list of non-model data volumes visited, in order
+        self.param_non_model_data_map: Dict[Any, List[int]] = {}
         self._model_data_cuda_list = []
@@ -23,6 +31,8 @@ class MemStats(object):
         self._param_runtime_order = OrderedParamGenerator()
 
+        self._preop_step = 0
+
     def param_order(self):
         if self._param_runtime_order.is_empty():
             raise RuntimeError
@@ -113,6 +123,38 @@ class MemStats(object):
         else:
             raise TypeError
 
+    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
+        """
+        Increase the preop time step by one. The params in param_list are the
+        ones used between the current time step and the next one.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch parameters.
+        """
+        for p in param_list:
+            if p not in self._param_step_dict:
+                self._param_step_dict[p] = [self._preop_step]
+            else:
+                self._param_step_dict[p].append(self._preop_step)
+            self._param_runtime_order.append(p)
+        self._step_param_dict[self._preop_step] = param_list
+        self._preop_step += 1
+
+    def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
+        """
+        Get the list of preop time steps at which the param was used.
+
+        Args:
+            param (torch.nn.Parameter): a torch parameter.
+
+        Returns:
+            Optional[List[int]]: the preop-hook time steps that used the param,
+            or None if the param was never recorded.
+        """
+        if param not in self._param_step_dict:
+            return None
+        else:
+            return self._param_step_dict[param]
+
     def clear(self):
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -124,3 +166,6 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
 
         self._param_runtime_order.clear()
+        self._step_param_dict.clear()
+        self._param_step_dict.clear()
+        self._preop_step = 0
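Taken together, the two new dicts form a bidirectional mapping: _step_param_dict answers "which params does preop step t touch?", and _param_step_dict answers "at which preop steps is param p touched?". The following is a minimal standalone sketch of that mechanism, runnable without ColossalAI; TimestepParamMap is a hypothetical name and the OrderedParamGenerator bookkeeping is omitted, but the two dicts behave as in increase_preop_step and param_used_timestep above.

from typing import Dict, List, Optional

import torch


class TimestepParamMap:
    """Hypothetical trimmed-down MemStats: only the preop-step <-> param mapping."""

    def __init__(self) -> None:
        # preop step -> params used between that step and the next
        self._step_param_dict: Dict[int, List[torch.nn.Parameter]] = {}
        # param -> preop steps at which it was used (tensors hash by identity)
        self._param_step_dict: Dict[torch.nn.Parameter, List[int]] = {}
        self._preop_step = 0

    def increase_preop_step(self, param_list: List[torch.nn.Parameter]) -> None:
        for p in param_list:
            self._param_step_dict.setdefault(p, []).append(self._preop_step)
        self._step_param_dict[self._preop_step] = param_list
        self._preop_step += 1

    def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
        return self._param_step_dict.get(param)


w1 = torch.nn.Parameter(torch.zeros(4))
w2 = torch.nn.Parameter(torch.zeros(4))
m = TimestepParamMap()
m.increase_preop_step([w1])        # step 0: only w1 is used
m.increase_preop_step([w1, w2])    # step 1: both are used
print(m.param_used_timestep(w1))   # [0, 1]
print(m.param_used_timestep(w2))   # [1]
print(m.param_used_timestep(torch.nn.Parameter(torch.zeros(4))))  # None: never seen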
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -98,10 +98,7 @@ class ParamMemTracerHook(ColoParamOpHook):
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
-        # register the order of visited.
-        for p in params:
-            self._memstats._param_runtime_order.append(p)
+        self._memstats.increase_preop_step(params)
 
     def post_op(self, params):
         self._free_cuda_params(params)
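Since increase_preop_step itself appends every param to _param_runtime_order, the hook's hand-rolled registration loop is subsumed by the single call, which additionally records the step-to-param mapping and advances the counter. A minimal sketch of the resulting pre-op flow (illustrative names, not the real ColossalAI hook; the CUDA allocation and monitor steps are elided as comments):

class PreOpHookSketch:
    """Hypothetical stand-in for ParamMemTracerHook, showing only the bookkeeping."""

    def __init__(self, memstats):
        # any object exposing increase_preop_step, e.g. the sketch above
        self._memstats = memstats

    def pre_op(self, params):
        # ...allocate params on CUDA, sample model data, start the memory monitor...
        # a single call now registers visit order, maps step -> params,
        # and advances the preop step counter
        self._memstats.increase_preop_step(params)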
tests/test_gemini/update/test_gemini_use_rmt.py
@@ -45,7 +45,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
     print('runtime tracer: ', runtime_tracer_non_model_data)
+    print([memstats.param_used_timestep(p) for p in model.parameters()])
 
     model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
     zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
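On the new debug print: each list entry is either None (the param was never hit by a pre-op hook) or the list of preop steps at which it was used, so a param touched in both forward and backward will normally appear at more than one step. A hedged sketch of a more readable dump, assuming the runtime_mem_tracer and model from the test above (only memstats() and param_used_timestep come from this commit; the loop itself is illustrative):

memstats = runtime_mem_tracer.memstats()
for i, p in enumerate(model.parameters()):
    steps = memstats.param_used_timestep(p)
    desc = 'never used' if steps is None else f'preop steps {steps}'
    print(f'param {i} {tuple(p.shape)}: {desc}')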