Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4b055351
Unverified
Commit
4b055351
authored
Dec 07, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 07, 2022
Browse files
[Gemini] make RuntimeMemTracer work correctly (#2096)
parent
fa9d1aea
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
10 deletions
+29
-10
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+23
-5
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+1
-2
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+5
-3
No files found.
colossalai/gemini/memory_tracer/memory_stats.py
View file @
4b055351
...
@@ -35,13 +35,31 @@ class MemStats(object):
...
@@ -35,13 +35,31 @@ class MemStats(object):
else
:
else
:
raise
TypeError
raise
TypeError
def
append_non_model_data
(
self
,
device_type
:
str
):
def
last_model_data
(
self
,
device_type
:
str
):
if
len
(
self
.
_model_data_cuda_list
)
==
0
:
return
None
if
device_type
==
'cuda'
:
return
self
.
_model_data_cuda_list
[
-
1
]
elif
device_type
==
'cpu'
:
return
self
.
_model_data_cpu_list
[
-
1
]
else
:
raise
TypeError
def
append_non_model_data
(
self
,
device_type
:
str
,
val
=
None
):
if
device_type
==
'cuda'
:
if
val
is
None
:
if
len
(
self
.
_overall_cuda_list
)
==
0
or
len
(
self
.
_model_data_cuda_list
)
==
0
:
if
len
(
self
.
_overall_cuda_list
)
==
0
or
len
(
self
.
_model_data_cuda_list
)
==
0
:
return
return
if
device_type
==
'cuda'
:
self
.
_non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data_cuda_list
[
-
1
])
self
.
_non_model_data_cuda_list
.
append
(
self
.
_overall_cuda_list
[
-
1
]
-
self
.
_model_data_cuda_list
[
-
1
])
else
:
self
.
_non_model_data_cuda_list
.
append
(
val
)
elif
device_type
==
'cpu'
:
elif
device_type
==
'cpu'
:
if
val
is
None
:
if
len
(
self
.
_overall_cuda_list
)
==
0
or
len
(
self
.
_model_data_cuda_list
)
==
0
:
return
self
.
_non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data_cpu_list
[
-
1
])
self
.
_non_model_data_cpu_list
.
append
(
self
.
_overall_cpu_list
[
-
1
]
-
self
.
_model_data_cpu_list
[
-
1
])
else
:
self
.
_non_model_data_cuda_list
.
append
(
val
)
else
:
else
:
raise
TypeError
raise
TypeError
...
...
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
View file @
4b055351
...
@@ -76,8 +76,7 @@ class RuntimeMemTracer():
...
@@ -76,8 +76,7 @@ class RuntimeMemTracer():
def
_post_backward
(
self
):
def
_post_backward
(
self
):
cuda_volume
=
self
.
param_op_hook
.
mem_monitor
.
finish
()
cuda_volume
=
self
.
param_op_hook
.
mem_monitor
.
finish
()
self
.
_memstats
.
append_model_data
(
'cuda'
,
cuda_volume
)
self
.
_memstats
.
append_non_model_data
(
'cuda'
,
cuda_volume
-
self
.
_memstats
.
last_model_data
(
'cuda'
))
self
.
_memstats
.
append_non_model_data
(
'cuda'
)
self
.
grad_hook
.
remove_grad_hook
()
self
.
grad_hook
.
remove_grad_hook
()
self
.
_restore_params
()
self
.
_restore_params
()
...
...
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
View file @
4b055351
...
@@ -5,7 +5,7 @@ from typing import List
...
@@ -5,7 +5,7 @@ from typing import List
import
torch
import
torch
from
colossalai.gemini.memory_tracer
import
SyncCudaMemoryMonitor
from
colossalai.gemini.memory_tracer
import
MemStats
,
SyncCudaMemoryMonitor
from
colossalai.gemini.tensor_utils
import
alloc_storage
,
free_storage
from
colossalai.gemini.tensor_utils
import
alloc_storage
,
free_storage
from
colossalai.tensor.param_op_hook
import
ColoParamOpHook
from
colossalai.tensor.param_op_hook
import
ColoParamOpHook
...
@@ -51,7 +51,7 @@ class GradMemTracerHook():
...
@@ -51,7 +51,7 @@ class GradMemTracerHook():
class
ParamMemTracerHook
(
ColoParamOpHook
):
class
ParamMemTracerHook
(
ColoParamOpHook
):
def
__init__
(
self
,
memstats
,
gradstats
:
GradMemStats
)
->
None
:
def
__init__
(
self
,
memstats
:
MemStats
,
gradstats
:
GradMemStats
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
_training_phase
=
TrainingPhase
.
FORWARD
self
.
_training_phase
=
TrainingPhase
.
FORWARD
self
.
_memstats
=
memstats
self
.
_memstats
=
memstats
...
@@ -92,7 +92,9 @@ class ParamMemTracerHook(ColoParamOpHook):
...
@@ -92,7 +92,9 @@ class ParamMemTracerHook(ColoParamOpHook):
def
pre_op
(
self
,
params
):
def
pre_op
(
self
,
params
):
cuda_volume
=
self
.
mem_monitor
.
finish
()
cuda_volume
=
self
.
mem_monitor
.
finish
()
self
.
_memstats
.
append_model_data
(
'cuda'
,
cuda_volume
)
last_model_data_val
=
self
.
_memstats
.
last_model_data
(
'cuda'
)
if
last_model_data_val
is
not
None
:
self
.
_memstats
.
append_non_model_data
(
'cuda'
,
cuda_volume
-
last_model_data_val
)
self
.
_allocate_params_on_cuda
(
params
)
self
.
_allocate_params_on_cuda
(
params
)
self
.
sample_model_data
(
params
)
self
.
sample_model_data
(
params
)
self
.
mem_monitor
.
start
()
self
.
mem_monitor
.
start
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment