Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
340e59f9
Unverified
Commit
340e59f9
authored
Apr 13, 2022
by
HELSON
Committed by
GitHub
Apr 13, 2022
Browse files
[utils] add synchronized cuda memory monitor (#740)
parent
e6212f56
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
149 additions
and
110 deletions
+149
-110
colossalai/trainer/hooks/_mem_tracer_hook.py
colossalai/trainer/hooks/_mem_tracer_hook.py
+4
-4
colossalai/utils/memory_tracer/__init__.py
colossalai/utils/memory_tracer/__init__.py
+2
-2
colossalai/utils/memory_tracer/memory_monitor.py
colossalai/utils/memory_tracer/memory_monitor.py
+142
-103
colossalai/utils/memory_tracer/memstats_collector.py
colossalai/utils/memory_tracer/memstats_collector.py
+1
-1
No files found.
colossalai/trainer/hooks/_mem_tracer_hook.py
View file @
340e59f9
from
cgitb
import
Hook
from
colossalai.registry
import
HOOKS
from
colossalai.registry
import
HOOKS
from
torch
import
Tensor
from
torch
import
Tensor
from
colossalai.trainer.hooks
import
BaseHook
from
colossalai.trainer.hooks
import
BaseHook
from
colossalai.utils.memory_tracer
import
AsyncMemoryMonitor
from
colossalai.utils.memory_tracer
import
AsyncMemoryMonitor
from
._metric_hook
import
LearningRateMetric
,
MetricHook
@
HOOKS
.
register_module
@
HOOKS
.
register_module
class
MemTraceHook
(
BaseHook
):
class
MemTraceHook
(
BaseHook
):
...
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
...
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
This hook is used to record memory usage info, and pass to trainer.states
This hook is used to record memory usage info, and pass to trainer.states
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
priority
:
int
=
0
,
priority
:
int
=
0
,
...
...
colossalai/utils/memory_tracer/__init__.py
View file @
340e59f9
from
.
async_memtrace
r
import
AsyncMemoryMonitor
from
.
memory_monito
r
import
AsyncMemoryMonitor
,
SyncCudaMemoryMonitor
from
.memstats_collector
import
MemStatsCollector
from
.memstats_collector
import
MemStatsCollector
__all__
=
[
'AsyncMemoryMonitor'
,
'MemStatsCollector'
]
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
]
colossalai/utils/memory_tracer/
async_memtrace
r.py
→
colossalai/utils/memory_tracer/
memory_monito
r.py
View file @
340e59f9
from
abc
import
abstractmethod
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
time
import
sleep
,
time
from
time
import
sleep
,
time
import
pickle
import
json
import
torch
import
torch
...
@@ -8,7 +9,42 @@ from colossalai.utils.memory import colo_device_memory_used
...
@@ -8,7 +9,42 @@ from colossalai.utils.memory import colo_device_memory_used
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
class
AsyncMemoryMonitor
:
class
MemoryMonitor
:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
def
__init__
(
self
):
self
.
time_stamps
=
[]
self
.
mem_stats
=
[]
def
__len__
(
self
):
return
len
(
self
.
mem_stats
)
@
abstractmethod
def
start
(
self
):
pass
@
abstractmethod
def
finish
(
self
):
pass
def
state_dict
(
self
):
return
{
"time_stamps"
:
self
.
time_stamps
,
"mem_stats"
:
self
.
mem_stats
,
}
def
save
(
self
,
filename
):
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
self
.
state_dict
(),
f
)
def
clear
(
self
):
self
.
mem_stats
.
clear
()
self
.
time_stamps
.
clear
()
class
AsyncMemoryMonitor
(
MemoryMonitor
):
"""
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
at interval of `1/(10**power)` sec.
...
@@ -31,15 +67,15 @@ class AsyncMemoryMonitor:
...
@@ -31,15 +67,15 @@ class AsyncMemoryMonitor:
async_mem_monitor.finish()
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
async_mem_monitor.save('log.pkl')
Args:
Args:
power (int, optional): the power of time interva. Defaults to 10.
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar
\
: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
https://arxiv.org/abs/2108.05818
"""
"""
def
__init__
(
self
,
power
:
int
=
10
):
def
__init__
(
self
,
power
:
int
=
10
):
super
().
__init__
()
self
.
keep_measuring
=
False
self
.
keep_measuring
=
False
current_device
=
get_current_device
()
current_device
=
get_current_device
()
...
@@ -50,11 +86,6 @@ class AsyncMemoryMonitor:
...
@@ -50,11 +86,6 @@ class AsyncMemoryMonitor:
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
1
,
initializer
=
_set_cuda_device
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
1
,
initializer
=
_set_cuda_device
)
self
.
monitor_thread
=
None
self
.
monitor_thread
=
None
self
.
interval
=
1
/
(
10
**
power
)
self
.
interval
=
1
/
(
10
**
power
)
self
.
time_stamps
=
[]
self
.
mem_stats
=
[]
def
__len__
(
self
):
return
len
(
self
.
mem_stats
)
def
set_interval
(
self
,
power
:
int
):
def
set_interval
(
self
,
power
:
int
):
self
.
clear
()
self
.
clear
()
...
@@ -70,8 +101,10 @@ class AsyncMemoryMonitor:
...
@@ -70,8 +101,10 @@ class AsyncMemoryMonitor:
def
finish
(
self
):
def
finish
(
self
):
if
self
.
keep_measuring
is
False
:
if
self
.
keep_measuring
is
False
:
return
0
return
0
self
.
keep_measuring
=
False
self
.
keep_measuring
=
False
max_usage
=
self
.
monitor_thread
.
result
()
max_usage
=
self
.
monitor_thread
.
result
()
self
.
monitor_thread
=
None
self
.
monitor_thread
=
None
self
.
time_stamps
.
append
(
time
())
self
.
time_stamps
.
append
(
time
())
self
.
mem_stats
.
append
(
max_usage
)
self
.
mem_stats
.
append
(
max_usage
)
...
@@ -87,17 +120,23 @@ class AsyncMemoryMonitor:
...
@@ -87,17 +120,23 @@ class AsyncMemoryMonitor:
sleep
(
self
.
interval
)
sleep
(
self
.
interval
)
return
max_usage
return
max_usage
@
property
def
state_dict
(
self
):
return
{
"time_stamps"
:
self
.
time_stamps
,
"mem_stats"
:
self
.
mem_stats
,
}
def
save
(
self
,
filename
):
class
SyncCudaMemoryMonitor
(
MemoryMonitor
):
with
open
(
filename
,
"wb"
)
as
f
:
"""
pickle
.
dump
(
self
.
state_dict
(),
f
)
A synchronized cuda memory monitor.
It only record the maximum allocated cuda memory from start point to finish point.
"""
def
clear
(
self
):
def
__init__
(
self
,
power
:
int
=
10
):
self
.
mem_stats
.
clear
()
super
().
__init__
()
self
.
time_stamps
.
clear
()
def
start
(
self
):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
reset_peak_memory_stats
()
def
finish
(
self
):
torch
.
cuda
.
synchronize
()
self
.
time_stamps
.
append
(
time
())
max_usage
=
torch
.
cuda
.
max_memory_allocated
()
self
.
mem_stats
.
append
(
max_usage
)
return
max_usage
colossalai/utils/memory_tracer/memstats_collector.py
View file @
340e59f9
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.utils.memory
import
colo_device_memory_used
from
colossalai.utils.memory_tracer
.async_memtracer
import
AsyncMemoryMonitor
from
colossalai.utils.memory_tracer
import
AsyncMemoryMonitor
import
torch
import
torch
import
time
import
time
from
typing
import
List
from
typing
import
List
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment