Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0529fcde
Unverified
Commit
0529fcde
authored
Nov 18, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 18, 2022
Browse files
[Gemini] independent runtime tracer (#1974)
parent
0da1d003
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
271 additions
and
143 deletions
+271
-143
colossalai/gemini/memory_tracer/__init__.py
colossalai/gemini/memory_tracer/__init__.py
+2
-1
colossalai/gemini/memory_tracer/memory_monitor.py
colossalai/gemini/memory_tracer/memory_monitor.py
+147
-142
colossalai/gemini/memory_tracer/module_tracer_wrapper.py
colossalai/gemini/memory_tracer/module_tracer_wrapper.py
+36
-0
colossalai/gemini/ophooks/mem_trace_hook.py
colossalai/gemini/ophooks/mem_trace_hook.py
+86
-0
No files found.
colossalai/gemini/memory_tracer/__init__.py
View file @
0529fcde
...
...
@@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector # isort:skip
from
.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
# isort:skip
from
.chunk_memstats_collector
import
ChunkMemStatsCollector
# isort:skip
from
.static_memstats_collector
import
StaticMemStatsCollector
# isort:skip
from
.module_tracer_wrapper
import
MemtracerWrapper
# isort:skip
__all__
=
[
'AsyncMemoryMonitor'
,
'SyncCudaMemoryMonitor'
,
'MemStatsCollector'
,
'ChunkMemStatsCollector'
,
'StaticMemStatsCollector'
,
'GLOBAL_MODEL_DATA_TRACER'
'StaticMemStatsCollector'
,
'GLOBAL_MODEL_DATA_TRACER'
,
'MemtracerWrapper'
]
colossalai/gemini/memory_tracer/memory_monitor.py
View file @
0529fcde
from
abc
import
abstractmethod
from
concurrent.futures
import
ThreadPoolExecutor
from
time
import
sleep
,
time
import
json
import
torch
from
colossalai.utils
import
colo_device_memory_used
from
colossalai.utils
import
get_current_device
class
MemoryMonitor
:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
def
__init__
(
self
):
self
.
time_stamps
=
[]
self
.
mem_stats
=
[]
def
__len__
(
self
):
return
len
(
self
.
mem_stats
)
@
abstractmethod
def
start
(
self
):
pass
@
abstractmethod
def
finish
(
self
):
pass
def
state_dict
(
self
):
return
{
"time_stamps"
:
self
.
time_stamps
,
"mem_stats"
:
self
.
mem_stats
,
}
def
save
(
self
,
filename
):
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
self
.
state_dict
(),
f
)
def
clear
(
self
):
self
.
mem_stats
.
clear
()
self
.
time_stamps
.
clear
()
class
AsyncMemoryMonitor
(
MemoryMonitor
):
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
Usage::
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
Args:
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""
def
__init__
(
self
,
power
:
int
=
10
):
super
().
__init__
()
self
.
keep_measuring
=
False
current_device
=
get_current_device
()
def
_set_cuda_device
():
torch
.
cuda
.
set_device
(
current_device
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
1
,
initializer
=
_set_cuda_device
)
self
.
monitor_thread
=
None
self
.
interval
=
1
/
(
10
**
power
)
def
set_interval
(
self
,
power
:
int
):
self
.
clear
()
self
.
interval
=
1
/
(
10
**
power
)
def
is_measuring
(
self
):
return
self
.
keep_measuring
def
start
(
self
):
self
.
keep_measuring
=
True
self
.
monitor_thread
=
self
.
executor
.
submit
(
self
.
_measure_usage
)
def
finish
(
self
):
if
self
.
keep_measuring
is
False
:
return
0
self
.
keep_measuring
=
False
max_usage
=
self
.
monitor_thread
.
result
()
self
.
monitor_thread
=
None
self
.
time_stamps
.
append
(
time
())
self
.
mem_stats
.
append
(
max_usage
)
return
max_usage
def
_measure_usage
(
self
):
max_usage
=
0
while
self
.
keep_measuring
:
max_usage
=
max
(
max_usage
,
colo_device_memory_used
(
get_current_device
()),
)
sleep
(
self
.
interval
)
return
max_usage
class
SyncCudaMemoryMonitor
(
MemoryMonitor
):
"""
A synchronized cuda memory monitor.
It only record the maximum allocated cuda memory from start point to finish point.
"""
def
__init__
(
self
,
power
:
int
=
10
):
super
().
__init__
()
def
start
(
self
):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
reset_peak_memory_stats
()
def
finish
(
self
):
torch
.
cuda
.
synchronize
()
self
.
time_stamps
.
append
(
time
())
max_usage
=
torch
.
cuda
.
max_memory_allocated
()
self
.
mem_stats
.
append
(
max_usage
)
return
max_usage
import
json
from
abc
import
abstractmethod
from
concurrent.futures
import
ThreadPoolExecutor
from
time
import
sleep
,
time
import
torch
from
colossalai.utils
import
colo_device_memory_used
,
get_current_device
class
MemoryMonitor
:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
def
__init__
(
self
):
self
.
time_stamps
=
[]
self
.
mem_stats
=
[]
def
__len__
(
self
):
return
len
(
self
.
mem_stats
)
@
abstractmethod
def
start
(
self
):
pass
@
abstractmethod
def
finish
(
self
):
pass
def
state_dict
(
self
):
return
{
"time_stamps"
:
self
.
time_stamps
,
"mem_stats"
:
self
.
mem_stats
,
}
def
save
(
self
,
filename
):
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
self
.
state_dict
(),
f
)
def
clear
(
self
):
self
.
mem_stats
.
clear
()
self
.
time_stamps
.
clear
()
class
AsyncMemoryMonitor
(
MemoryMonitor
):
"""
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
Usage::
async_mem_monitor = AsyncMemoryMonitor()
input = torch.randn(2, 20).cuda()
OP1 = torch.nn.Linear(20, 30).cuda()
OP2 = torch.nn.Linear(30, 40).cuda()
async_mem_monitor.start()
output = OP1(input)
async_mem_monitor.finish()
async_mem_monitor.start()
output = OP2(output)
async_mem_monitor.finish()
async_mem_monitor.save('log.pkl')
Args:
power (int, optional): the power of time interva. Defaults to 10.
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""
def
__init__
(
self
,
power
:
int
=
10
):
super
().
__init__
()
self
.
keep_measuring
=
False
current_device
=
get_current_device
()
def
_set_cuda_device
():
torch
.
cuda
.
set_device
(
current_device
)
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
1
,
initializer
=
_set_cuda_device
)
self
.
monitor_thread
=
None
self
.
interval
=
1
/
(
10
**
power
)
def
set_interval
(
self
,
power
:
int
):
self
.
clear
()
self
.
interval
=
1
/
(
10
**
power
)
def
is_measuring
(
self
):
return
self
.
keep_measuring
def
start
(
self
):
self
.
keep_measuring
=
True
self
.
monitor_thread
=
self
.
executor
.
submit
(
self
.
_measure_usage
)
def
finish
(
self
):
if
self
.
keep_measuring
is
False
:
return
0
self
.
keep_measuring
=
False
max_usage
=
self
.
monitor_thread
.
result
()
self
.
monitor_thread
=
None
self
.
time_stamps
.
append
(
time
())
self
.
mem_stats
.
append
(
max_usage
)
return
max_usage
def
_measure_usage
(
self
):
max_usage
=
0
while
self
.
keep_measuring
:
max_usage
=
max
(
max_usage
,
colo_device_memory_used
(
get_current_device
()),
)
sleep
(
self
.
interval
)
return
max_usage
class
SyncCudaMemoryMonitor
(
MemoryMonitor
):
"""
A synchronized cuda memory monitor.
It only record the maximum allocated cuda memory from start point to finish point.
"""
def
__init__
(
self
,
power
:
int
=
10
):
super
().
__init__
()
def
start
(
self
):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
reset_peak_memory_stats
()
def
finish
(
self
)
->
int
:
"""
return max gpu memory used since latest `start()`.
Returns:
int: max GPU memory
"""
torch
.
cuda
.
synchronize
()
self
.
time_stamps
.
append
(
time
())
max_usage
=
torch
.
cuda
.
max_memory_allocated
()
self
.
mem_stats
.
append
(
max_usage
)
return
max_usage
colossalai/gemini/memory_tracer/module_tracer_wrapper.py
0 → 100644
View file @
0529fcde
from
colossalai.gemini.ophooks
import
register_ophooks_recursively
from
colossalai.gemini.ophooks.mem_trace_hook
import
MemTracerOpHook
__all__
=
[
'MemtracerWrapper'
]
class
_Wrapper
():
def
__init__
(
self
,
model
,
ophook_list
):
self
.
_ophook_list
=
ophook_list
self
.
_model
=
model
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
_model
(
*
args
,
**
kwargs
)
def
forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
_model
.
forward
(
*
args
,
**
kwargs
)
def
backward
(
self
,
loss
):
loss
.
backward
()
for
ophook
in
self
.
_ophook_list
:
ophook
.
post_iter
()
def
save_results
(
self
,
filename
):
for
ophook
in
self
.
_ophook_list
:
ophook
.
save_results
(
filename
)
def
show_mem_stats
(
self
):
self
.
_ophook_list
[
0
].
show_mem_stats
()
def
MemtracerWrapper
(
model
):
ophook_list
=
[
MemTracerOpHook
()]
register_ophooks_recursively
(
model
,
ophook_list
)
engine
=
_Wrapper
(
model
,
ophook_list
)
return
engine
colossalai/gemini/ophooks/mem_trace_hook.py
0 → 100644
View file @
0529fcde
import
torch
from
colossalai.gemini.memory_tracer
import
SyncCudaMemoryMonitor
from
colossalai.gemini.ophooks
import
BaseOpHook
class
MemTracerOpHook
(
BaseOpHook
):
def
__init__
(
self
):
super
().
__init__
()
self
.
mem_monitor
=
SyncCudaMemoryMonitor
()
self
.
_cur_non_model_data_vol
=
0
self
.
_non_model_data_list
=
[]
self
.
_cur_model_data_vol
=
0
def
_move_module_to_dev
(
self
,
module
,
dev
:
str
)
->
int
:
"""_move_module_to_dev
move module to cuda
Args:
module (torch.nn.Module): a PyTorch module
dev (torch.device): the target device
Returns:
int: the data volume of this module on the cuda
"""
assert
isinstance
(
dev
,
str
),
f
"device should be a str not torch.device"
comm_volume
=
0
for
p
in
module
.
parameters
():
if
p
.
data
.
device
.
type
!=
dev
:
p
.
data
=
p
.
data
.
to
(
dev
)
comm_volume
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
if
p
.
grad
is
not
None
:
if
p
.
grad
.
device
.
type
!=
dev
:
p
.
grad
=
p
.
grad
.
to
(
dev
)
comm_volume
+=
p
.
grad
.
numel
()
*
p
.
grad
.
element_size
()
if
dev
==
'cuda'
:
self
.
_cur_model_data_vol
=
comm_volume
return
comm_volume
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cuda'
)
self
.
mem_monitor
.
start
()
# print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
# print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
self
.
_move_module_to_dev
(
module
,
'cuda'
)
self
.
mem_monitor
.
start
()
# print(f'BWD PRE {module.__class__.__name__}')
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
# bwd Op will generate grad. comm_volume is grad + data volume on cuda.
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
# print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
def
pre_iter
(
self
):
pass
def
post_iter
(
self
):
self
.
mem_monitor
.
finish
()
# print(f'post_iter')
def
save_results
(
self
,
filename
):
self
.
mem_monitor
.
save
(
filename
)
def
show_mem_stats
(
self
):
start_timestamp
=
min
(
self
.
mem_monitor
.
time_stamps
)
self
.
mem_monitor
.
time_stamps
=
[
elem
-
start_timestamp
for
elem
in
self
.
mem_monitor
.
time_stamps
]
min_mem_used
=
min
(
self
.
mem_monitor
.
mem_stats
)
self
.
mem_monitor
.
mem_stats
=
[
elem
-
min_mem_used
for
elem
in
self
.
mem_monitor
.
mem_stats
]
print
(
self
.
mem_monitor
.
time_stamps
)
print
(
self
.
mem_monitor
.
mem_stats
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment