Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
d3446892
Commit
d3446892
authored
Mar 04, 2022
by
Jie Zhu
Committed by
Frank Lee
Mar 11, 2022
Browse files
[profiler] primary memory tracer
parent
dfc3fafe
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
19 deletions
+93
-19
colossalai/engine/_base_engine.py
colossalai/engine/_base_engine.py
+2
-0
colossalai/engine/ophooks/_memtracer_ophook.py
colossalai/engine/ophooks/_memtracer_ophook.py
+78
-17
colossalai/utils/timer.py
colossalai/utils/timer.py
+10
-0
tests/test_config/test_load_config.py
tests/test_config/test_load_config.py
+3
-2
No files found.
colossalai/engine/_base_engine.py
View file @
d3446892
...
...
@@ -26,6 +26,8 @@ class Engine:
:type gradient_handlers: list
:param clip_grad_norm: The norm of gradient clipping
:type clip_grad_norm: float, optional
:param ophook_list: List of ophook
:type ophook_list: list
:param verbose: whether to display log info
:type verbose: bool
"""
...
...
colossalai/engine/ophooks/_memtracer_ophook.py
View file @
d3446892
from
re
import
S
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch
from
.
import
BaseOpHook
from
concurrent.futures
import
ThreadPoolExecutor
...
...
@@ -5,18 +7,20 @@ from colossalai.registry import OPHOOKS
from
colossalai.logging
import
get_dist_logger
from
time
import
sleep
,
time
import
pickle
from
typing
import
Union
,
Optional
from
colossalai.core
import
global_context
as
gpc
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
    """Return the number of bytes currently allocated on ``device`` by this process.

    NOTE(review): the previous docstring claimed this returns the *free* memory
    (and 1/N of the total for CPU), but the code reports
    ``torch.cuda.memory_allocated`` — the docstring is corrected to match.

    :param device: CUDA device to query; ``None`` selects the current device
    :return: allocated bytes as reported by the CUDA caching allocator
    """
    ret: int = torch.cuda.memory_allocated(device)
    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"):
        # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats(device)
    return ret
...
...
@@ -34,6 +38,9 @@ class AsyncMemoryMonitor:
self
.
time_stamps
=
[]
self
.
mem_stats
=
[]
def __len__(self):
    """Number of memory samples collected so far."""
    return len(self.mem_stats)
def set_interval(self, power: int):
    """Use a sampling interval of ``1 / 10**power`` seconds."""
    denominator = 10 ** power
    self.interval = 1 / denominator
...
...
@@ -75,21 +82,64 @@ class AsyncMemoryMonitor:
with
open
(
filename
,
"wb"
)
as
f
:
pickle
.
dump
(
self
.
state_dict
(),
f
)
def clear(self):
    """Drop every recorded sample together with its timestamp."""
    self.time_stamps.clear()
    self.mem_stats.clear()
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    '''
    Collect GPU memory usage information.

    Args:
        warmup (int): how many iterations to truncate before profiling,
            e.g. set to 5 and the data will start from the 6-th iteration
        refreshrate (int): how many valid iterations between two rewrites
            of the stats data file
        data_prefix (str): the prefix of the stats data file
            (fixes the old docstring, which documented a nonexistent
            ``datafile`` parameter instead of ``data_prefix``)

    Attributes:
        _warmup (int): warmup iterations
        _refreshrate (int): how many iterations we shall refresh the file
        _logger (colossalai.logging.logger): output log file
        _curiter (int): current iteration number
        _count (int): the number of times the data file was written
        _data_prefix (str): the prefix of the stats data file
        _rank (int): the rank of current node
    '''

    def __init__(self,
                 warmup: int = 50,
                 refreshrate: int = 10,
                 data_prefix: str = "memstats"):
        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._curiter = 0
        self._logger = get_dist_logger()
        self._count = 0
        self._warmup = warmup
        self._refreshrate = refreshrate
        self._data_prefix = data_prefix
        # in a distributed environment every rank dumps its own stats file
        if gpc.is_initialized(ParallelMode.GLOBAL):
            self._rank = gpc.get_global_rank()
        else:
            self._rank = 0
def _isvalid(self, module) -> bool:
    """A module is only profiled while it is in training mode."""
    assert isinstance(module, torch.nn.Module)
    return module.training
@property
def refreshrate(self) -> int:
    """How many valid iterations pass between two rewrites of the stats file."""
    rate = self._refreshrate
    return rate
@property
def warmup(self) -> int:
    """Number of leading iterations excluded from profiling."""
    iterations = self._warmup
    return iterations
@property
def curiter(self) -> int:
    """Index of the iteration currently being executed."""
    current = self._curiter
    return current
@property
def valid_iter(self) -> int:
    """Iteration index counted from the end of the warmup stage."""
    offset = self.warmup
    return self.curiter - offset
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
if
self
.
_isvalid
(
module
):
...
...
@@ -103,14 +153,12 @@ class MemTracerOpHook(BaseOpHook):
self
.
_logger
.
debug
(
f
'FWD POST
{
module
.
__class__
.
__name__
}
'
)
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
    """Restart the async memory monitor right before *module*'s backward pass."""
    assert isinstance(module, torch.nn.Module)
    if not self._isvalid(module):
        return
    # close the sampling window of the previous op, then open a fresh one
    self.async_mem_monitor.finish()
    self.async_mem_monitor.start()
    self._logger.debug(f'BWD PRE {module.__class__.__name__}')
def post_bwd_exec(self, module: torch.nn.Module, input):
    """Stop the memory sampling window once *module*'s backward pass is done."""
    assert isinstance(module, torch.nn.Module)
    if not self._isvalid(module):
        return
    self.async_mem_monitor.finish()
    self._logger.debug(f'BWD POST {module.__class__.__name__}')
...
...
@@ -120,11 +168,24 @@ class MemTracerOpHook(BaseOpHook):
def post_iter(self):
    """Bookkeeping at the end of one training iteration.

    Stops the current sampling window, discards samples collected during
    warmup, periodically dumps the stats to disk, and advances the
    iteration counter.
    """
    self.async_mem_monitor.finish()
    # in the warmup stage
    if self._curiter < self.warmup:
        # TODO: record time and adaptively change sampling rate
        pass
    elif self._curiter == self._warmup:
        # warmup just ended: drop warmup-phase samples so the stats start clean
        self.async_mem_monitor.clear()
    else:
        # every `refreshrate` times, refresh the file
        if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
            # output file info
            # BUGFIX: was `self._dataprefix`, an attribute that is never set
            # (__init__ defines `_data_prefix`), which raised AttributeError here.
            self._logger.info(f'dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl')
            self.save_results()
            self._count += 1
            self._logger.debug(f'data file has been refreshed {self._count} times')
    # finish a iteration
    self._curiter += 1
def save_results(self):
    """Dump the collected memory stats to ``<data_prefix>-<rank>.pkl``."""
    target = "-".join([self._data_prefix, str(self._rank)]) + ".pkl"
    self.async_mem_monitor.save(target)
colossalai/utils/timer.py
View file @
d3446892
...
...
@@ -19,6 +19,11 @@ class Timer:
def has_history(self):
    """Whether at least one timing record has been stored."""
    return bool(self._history)
@property
def current_time(self) -> float:
    """Wall-clock time after flushing outstanding device work."""
    # NOTE(review): `synchronize` is imported elsewhere in this file —
    # presumably a CUDA synchronize so timings include pending kernels; confirm.
    synchronize()
    return time.time()
def
start
(
self
):
"""Fisrtly synchronize cuda, reset the clock and then start the timer.
"""
...
...
@@ -27,6 +32,11 @@ class Timer:
self
.
_start_time
=
time
.
time
()
self
.
_started
=
True
def lap(self):
    """Return the time elapsed since the timer started, without stopping it."""
    elapsed = self.current_time - self._start_time
    return elapsed
def
stop
(
self
,
keep_in_history
:
bool
=
False
):
"""Stop the timer and record the start-stop time interval.
...
...
tests/test_config/test_load_config.py
View file @
d3446892
...
...
@pytest.mark.cpu
def test_load_ophooks():
    """A MemTracerOpHook built from a config mapping exposes the configured values."""
    # renamed from `dict`, which shadowed the builtin of the same name
    config = {'type': 'MemTracerOpHook', 'warmup': 10, 'refreshrate': 20}
    ophook = build_ophooks(config)
    assert ophook.refreshrate == 20
    assert ophook.warmup == 10
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment