Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
70a85569
Unverified
Commit
70a85569
authored
Dec 09, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 09, 2022
Browse files
[gemini] get the param visited order during runtime (#2108)
parent
61f31c3c
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
48 additions
and
2 deletions
+48
-2
colossalai/gemini/memory_tracer/__init__.py
colossalai/gemini/memory_tracer/__init__.py
+2
-1
colossalai/gemini/memory_tracer/memory_stats.py
colossalai/gemini/memory_tracer/memory_stats.py
+6
-0
colossalai/gemini/memory_tracer/param_runtime_order.py
colossalai/gemini/memory_tracer/param_runtime_order.py
+25
-0
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
+4
-1
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+4
-0
tests/test_gemini/test_runtime_mem_tracer.py
tests/test_gemini/test_runtime_mem_tracer.py
+7
-0
No files found.
colossalai/gemini/memory_tracer/__init__.py
View file @
70a85569
from
.param_runtime_order
import
ParamRuntimeOrder
# isort:skip
from
.memory_stats
import
MemStats
# isort:skip
from
.memory_stats
import
MemStats
# isort:skip
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
# isort:skip
from
.memory_monitor
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
# isort:skip
from
.memstats_collector
import
MemStatsCollector
# isort:skip
from
.memstats_collector
import
MemStatsCollector
# isort:skip
...
@@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
...
@@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__
=
[
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
'StaticMemStatsCollector'
,
'MemStats'
'StaticMemStatsCollector'
,
'MemStats'
,
'ParamRuntimeOrder'
]
]
colossalai/gemini/memory_tracer/memory_stats.py
View file @
70a85569
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
from
colossalai.gemini.memory_tracer
import
ParamRuntimeOrder
class
MemStats
(
object
):
class
MemStats
(
object
):
...
@@ -19,6 +21,8 @@ class MemStats(object):
...
@@ -19,6 +21,8 @@ class MemStats(object):
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_param_runtime_order
=
ParamRuntimeOrder
()
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
def
append_overall_data
(
self
,
device_type
:
str
,
val
:
float
):
if
device_type
==
'cuda'
:
if
device_type
==
'cuda'
:
self
.
_overall_cuda_list
.
append
(
val
)
self
.
_overall_cuda_list
.
append
(
val
)
...
@@ -112,3 +116,5 @@ class MemStats(object):
...
@@ -112,3 +116,5 @@ class MemStats(object):
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cpu_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_non_model_data_cuda_list
=
[]
self
.
_param_runtime_order
.
clear
()
colossalai/gemini/memory_tracer/param_runtime_order.py
0 → 100644
View file @
70a85569
import
torch
class ParamRuntimeOrder(object):
    """Record of the parameters visited at runtime, in visiting order.

    Every call to :meth:`append` stores one entry, so a parameter that is
    visited several times appears several times in ``param_visited_order``.
    :meth:`generate` yields each distinct parameter once, in the order of its
    first visit.
    """

    def __init__(self) -> None:
        # Raw visit log; may contain the same parameter more than once.
        self.param_visited_order = list()

    def append(self, param: torch.nn.Parameter):
        """Log one visit of ``param`` at the end of the record."""
        self.param_visited_order.append(param)

    def generate(self):
        """Lazily yield each recorded parameter once, in first-visit order.

        Entries are de-duplicated with a set, so they must be hashable.
        """
        seen = set()
        for visited in self.param_visited_order:
            if visited in seen:
                # Already yielded on an earlier visit.
                continue
            seen.add(visited)
            yield visited

    def clear(self):
        """Drop all recorded visits by rebinding the log to a new empty list."""
        self.param_visited_order = list()
colossalai/gemini/memory_tracer/runtime_mem_tracer.py
View file @
70a85569
import
torch.nn
import
torch.nn
from
colossalai.gemini.memory_tracer
import
MemStats
from
colossalai.gemini.memory_tracer
import
MemStats
,
ParamRuntimeOrder
from
colossalai.gemini.ophooks.runtime_mem_tracer_hook
import
GradMemStats
,
GradMemTracerHook
,
ParamMemTracerHook
from
colossalai.gemini.ophooks.runtime_mem_tracer_hook
import
GradMemStats
,
GradMemTracerHook
,
ParamMemTracerHook
from
colossalai.nn.parallel.data_parallel
import
_cast_float
from
colossalai.nn.parallel.data_parallel
import
_cast_float
from
colossalai.tensor.param_op_hook
import
ColoParamOpHookManager
from
colossalai.tensor.param_op_hook
import
ColoParamOpHookManager
...
@@ -35,6 +35,9 @@ class RuntimeMemTracer():
...
@@ -35,6 +35,9 @@ class RuntimeMemTracer():
self
.
_cast_buffers_to_cuda_dtype
()
self
.
_cast_buffers_to_cuda_dtype
()
def parameters_in_runtime_order(self):
    """Return the result of ``generate()`` on the recorded runtime order.

    The order object is read from ``self._memstats._param_runtime_order``.
    """
    runtime_order = self._memstats._param_runtime_order
    return runtime_order.generate()
def
memstats
(
self
):
def
memstats
(
self
):
return
self
.
_memstats
return
self
.
_memstats
...
...
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
View file @
70a85569
...
@@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook):
...
@@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook):
self
.
sample_model_data
(
params
)
self
.
sample_model_data
(
params
)
self
.
mem_monitor
.
start
()
self
.
mem_monitor
.
start
()
# Record the order in which the parameters are visited.
for
p
in
params
:
self
.
_memstats
.
_param_runtime_order
.
append
(
p
)
def
post_op
(
self
,
params
):
def
post_op
(
self
,
params
):
self
.
_free_cuda_params
(
params
)
self
.
_free_cuda_params
(
params
)
...
...
tests/test_gemini/test_runtime_mem_tracer.py
View file @
70a85569
...
@@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
...
@@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
print
(
"cuda_non_model_data_list"
,
len
(
cuda_non_model_data_list
))
print
(
"cuda_non_model_data_list"
,
len
(
cuda_non_model_data_list
))
print
(
non_model_data_list
)
print
(
non_model_data_list
)
cnt1
=
0
for
p
in
runtime_mem_tracer
.
parameters_in_runtime_order
():
cnt1
+=
1
cnt2
=
0
for
p
in
model
.
parameters
():
cnt2
+=
1
assert
cnt2
==
cnt1
,
f
'visited param number
{
cnt1
}
vs real param number
{
cnt2
}
'
del
model
del
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment