OpenDAS / ColossalAI · Commits · 340e59f9

Unverified commit 340e59f9, authored Apr 13, 2022 by HELSON, committed by GitHub on Apr 13, 2022.
Parent: e6212f56

[utils] add synchronized cuda memory monitor (#740)

Showing 4 changed files with 149 additions and 110 deletions (+149 -110)
colossalai/trainer/hooks/_mem_tracer_hook.py          +4    -4
colossalai/utils/memory_tracer/__init__.py            +2    -2
colossalai/utils/memory_tracer/memory_monitor.py      +142  -103
colossalai/utils/memory_tracer/memstats_collector.py  +1    -1
colossalai/trainer/hooks/_mem_tracer_hook.py  (view file @ 340e59f9)

-from cgitb import Hook
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 from ._metric_hook import LearningRateMetric, MetricHook


 @HOOKS.register_module
 class MemTraceHook(BaseHook):
 ...

@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
     This hook is used to record memory usage info and pass it to trainer.states.
     You can use it like any other trainer hook and fetch data from trainer.states['metrics'][mode].
     """

     def __init__(self, priority: int = 0,
 ...

@@ -36,9 +36,9 @@ class MemTraceHook(BaseHook):
     def before_test_iter(self, trainer):
         self._memory_monitor.start()
         return super().before_test(trainer)

     def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         self._memory_monitor.finish()
         trainer.states['metrics']['train'] = self._memory_monitor.state_dict
         trainer.states['metrics']['test'] = self._memory_monitor.state_dict
-        return super().after_test_iter(trainer, output, label, loss)
\ No newline at end of file
+        return super().after_test_iter(trainer, output, label, loss)
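
For context, a minimal sketch (mine, not part of the commit) of reading back what this hook records. Note that after this commit `state_dict` becomes a plain method on the monitors (it was a property before), so the uncalled `self._memory_monitor.state_dict` stored above would hold a bound method rather than a dict; the sketch calls it defensively. The `trainer` object is assumed to be a colossalai Trainer with MemTraceHook attached.

# Hypothetical usage sketch: fetch the memory stats the hook stored.
entry = trainer.states['metrics']['test']
mem_info = entry() if callable(entry) else entry    # state_dict may be stored uncalled
for ts, peak in zip(mem_info['time_stamps'], mem_info['mem_stats']):
    print(f"t={ts:.3f}  peak={peak} bytes")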
colossalai/utils/memory_tracer/__init__.py  (view file @ 340e59f9)

-from .async_memtracer import AsyncMemoryMonitor
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
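
A quick sketch (mine, not part of the diff) of what this re-export change means for downstream imports; both names come from the `__all__` shown above.

# Both monitors are now importable from the memory_tracer package:
from colossalai.utils.memory_tracer import AsyncMemoryMonitor, SyncCudaMemoryMonitor

# The module rename means the old path stops working after this commit:
# from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor  # ImportError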
colossalai/utils/memory_tracer/async_memtracer.py → colossalai/utils/memory_tracer/memory_monitor.py  (view file @ 340e59f9)
Old version (async_memtracer.py, removed):

from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import pickle

import torch

from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device


class AsyncMemoryMonitor:
    """
    An Async Memory Monitor running during computing. It samples the memory usage of the current GPU
    at intervals of `1/(10**power)` sec.

    The idea comes from the Runtime Memory Tracer of PatrickStar
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.pkl')

    Args:
        power (int, optional): the power of the time interval. Defaults to 10.

    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            torch.cuda.set_device(current_device)

        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        self.monitor_thread = None
        self.interval = 1 / (10**power)
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0

        self.keep_measuring = False
        max_usage = self.monitor_thread.result()

        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                colo_device_memory_used(get_current_device()),
            )
            sleep(self.interval)
        return max_usage

    @property
    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        # NB: `state_dict` is a property here, so `self.state_dict()` calls the
        # returned dict and raises TypeError; the new version below makes
        # state_dict a regular method, which resolves this.
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()
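
A small illustration (mine, not part of the commit) of how the `power` argument maps to the sampling interval computed above as `1 / (10 ** power)`:

# interval = 1 / (10 ** power), in seconds
for power in (1, 3, 10):
    print(f"power={power:2d} -> sample every {1 / 10**power:g} s")
# power=10, the default, yields a 1e-10 s interval; sleep() with such a tiny
# argument returns almost immediately, so the worker thread re-reads
# colo_device_memory_used() essentially as fast as it can between start()
# and finish().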
New version (memory_monitor.py, added):

from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import json

import torch

from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device


class MemoryMonitor:
    """Base class for all types of memory monitor.
    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
    """

    def __init__(self):
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def finish(self):
        pass

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "w") as f:
            json.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()


class AsyncMemoryMonitor(MemoryMonitor):
    """
    An Async Memory Monitor running during computing. It samples the memory usage of the current GPU
    at intervals of `1/(10**power)` sec.

    The idea comes from the Runtime Memory Tracer of PatrickStar
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.pkl')

    Args:
        power (int, optional): the power of the time interval. Defaults to 10.

    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
        super().__init__()
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            torch.cuda.set_device(current_device)

        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        self.monitor_thread = None
        self.interval = 1 / (10**power)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0

        self.keep_measuring = False
        max_usage = self.monitor_thread.result()

        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                colo_device_memory_used(get_current_device()),
            )
            sleep(self.interval)
        return max_usage


class SyncCudaMemoryMonitor(MemoryMonitor):
    """
    A synchronized cuda memory monitor.
    It only records the maximum allocated cuda memory between the start point and the finish point.
    """

    def __init__(self, power: int = 10):
        super().__init__()

    def start(self):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self):
        torch.cuda.synchronize()
        self.time_stamps.append(time())
        max_usage = torch.cuda.max_memory_allocated()
        self.mem_stats.append(max_usage)
        return max_usage
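
A minimal usage sketch of the new SyncCudaMemoryMonitor (mine, not from the commit; the layer and input tensor are hypothetical). start() synchronizes the device and resets the CUDA allocator's peak-memory counter; finish() synchronizes again and records torch.cuda.max_memory_allocated().

import torch
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
layer = torch.nn.Linear(20, 30).cuda()    # hypothetical workload
x = torch.randn(2, 20).cuda()

monitor.start()                  # synchronize + reset peak stats
y = layer(x)
peak_bytes = monitor.finish()    # synchronize + read the recorded peak
print(f"peak allocated: {peak_bytes / 1024**2:.2f} MiB")

monitor.save('mem_stats.json')   # base-class save() serializes state_dict() as JSON

Unlike AsyncMemoryMonitor, this reads the allocator's own peak counter instead of polling from a worker thread, so it adds no sampling overhead, but it only sees memory allocated through PyTorch's caching allocator.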
colossalai/utils/memory_tracer/memstats_collector.py  (view file @ 340e59f9)

 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor

 import torch
 import time

 from typing import List
 ...