OpenDAS / ColossalAI · Commits

Commit 10e28264
authored Mar 09, 2022 by Jiarui Fang; committed Mar 11, 2022 by Frank Lee

move async memory to an individual directory (#345)

parent 425bb0df
Showing 4 changed files with 130 additions and 89 deletions:

- colossalai/engine/ophooks/_memtracer_ophook.py (+3 −89)
- colossalai/utils/memory_tracer/__init__.py (+3 −0)
- colossalai/utils/memory_tracer/async_memtracer.py (+108 −0)
- colossalai/utils/memory_tracer/test_async_memtracer.py (+16 −0)
colossalai/engine/ophooks/_memtracer_ophook.py @ 10e28264
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from . import BaseOpHook
-from concurrent.futures import ThreadPoolExecutor
+from colossalai.engine.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
-from time import sleep, time
-import pickle
-from typing import Optional
 from colossalai.core import global_context as gpc
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 import math
-
-
-def get_cuda_memory_used(device: Optional[torch.device]) -> int:
-    """Get the free memory info of device.
-    Notice that for CPU, this function will return 1/N of the total free memory,
-    where N is the world size.
-
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
-
-
-class AsyncMemoryMonitor:
-    """
-    An Async Mem Monitor runing during computing. Sampling GPU memory usage of the current GPU
-    at interval of 1/(10**power) sec.
-
-    :param power: the power of time interval, defaults to 10
-    :type power: int
-    """
-
-    def __init__(self, power: int = 10):
-        self.keep_measuring = False
-        self.executor = ThreadPoolExecutor(max_workers=1)
-        self.monitor_thread = None
-        self.interval = 1 / (10**power)
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)
-
-    def set_interval(self, power: int):
-        self.clear()
-        self.interval = 1 / (10**power)
-
-    def is_measuring(self):
-        return self.keep_measuring
-
-    def start(self):
-        self.keep_measuring = True
-        self.monitor_thread = self.executor.submit(self._measure_usage)
-
-    def finish(self):
-        if self.keep_measuring is False:
-            return 0
-        self.keep_measuring = False
-        max_usage = self.monitor_thread.result()
-        self.monitor_thread = None
-        self.time_stamps.append(time())
-        self.mem_stats.append(max_usage)
-        return max_usage
-
-    def _measure_usage(self):
-        max_usage = 0
-        dev = torch.device(f"cuda:{torch.cuda.current_device()}")
-        while self.keep_measuring:
-            max_usage = max(
-                max_usage,
-                get_cuda_memory_used(dev),
-            )
-            sleep(self.interval)
-        return max_usage
-
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "wb") as f:
-            pickle.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
-
-
 @OPHOOKS.register_module
 ...
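After this move, consumer code pulls the monitor from its new package path. A minimal sketch of the relocated API as a consumer might use it (the op and tensor below are illustrative, not part of the commit):

```python
import torch
from colossalai.utils.memory_tracer import AsyncMemoryMonitor

# power=3 gives a sampling interval of 1 / 10**3 = 1 ms,
# following the interval formula in the class docstring.
monitor = AsyncMemoryMonitor(power=3)

op = torch.nn.Linear(20, 30).cuda()
x = torch.randn(2, 20).cuda()

monitor.start()              # spawn the background sampling thread
y = op(x)
peak = monitor.finish()      # join the thread; returns peak allocated bytes
print(f'peak allocated: {peak} bytes')
```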
colossalai/utils/memory_tracer/__init__.py (new file, mode 100644) @ 10e28264

from .async_memtracer import AsyncMemoryMonitor

__all__ = ['AsyncMemoryMonitor']
colossalai/utils/memory_tracer/async_memtracer.py (new file, mode 100644) @ 10e28264

from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import pickle

from colossalai.utils import get_current_device

import torch


def _get_cuda_memory_used(device: torch.device) -> int:
    """
    Get the memory currently allocated on device.

    :param device: device id
    :type device: torch.device
    :return: current memory usage, in bytes
    :rtype: int
    """
    assert device.type == 'cuda'
    ret: int = torch.cuda.memory_allocated(device)
    # get the peak memory to report correct data, so reset the counter for the next call
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats(device)
    return ret


class AsyncMemoryMonitor:
    """
    An async memory monitor running during computing. Samples memory usage of the current GPU
    at an interval of 1/(10**power) sec.

    :param power: the power of time interval, defaults to 10
    :type power: int

    Usage:
    ```python
    async_mem_monitor = AsyncMemoryMonitor()
    input = torch.randn(2, 20).cuda()
    OP1 = torch.nn.Linear(20, 30).cuda()
    OP2 = torch.nn.Linear(30, 40).cuda()

    async_mem_monitor.start()
    output = OP1(input)
    async_mem_monitor.finish()
    async_mem_monitor.start()
    output = OP2(output)
    async_mem_monitor.finish()
    async_mem_monitor.save('log.pkl')
    ```
    """

    def __init__(self, power: int = 10):
        self.keep_measuring = False
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.monitor_thread = None
        self.interval = 1 / (10**power)
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0
        self.keep_measuring = False
        max_usage = self.monitor_thread.result()
        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                _get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
            )
            sleep(self.interval)
        return max_usage

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()
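save() pickles the state_dict shown above. A minimal sketch for reading such a log back, assuming a log.pkl produced by the docstring example:

```python
import pickle

# Load the log written by AsyncMemoryMonitor.save('log.pkl').
with open('log.pkl', 'rb') as f:
    log = pickle.load(f)

# finish() appends one (timestamp, peak usage) pair per start()/finish() window;
# torch.cuda.memory_allocated reports bytes, so convert for readability.
for ts, peak in zip(log['time_stamps'], log['mem_stats']):
    print(f'{ts:.3f}  peak allocated: {peak / 1024**2:.2f} MiB')
```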
colossalai/utils/memory_tracer/test_async_memtracer.py (new file, mode 100644) @ 10e28264

from async_memtracer import AsyncMemoryMonitor

import torch

if __name__ == '__main__':
    async_mem_monitor = AsyncMemoryMonitor()
    input = torch.randn(2, 20).cuda()
    OP1 = torch.nn.Linear(20, 30).cuda()
    OP2 = torch.nn.Linear(30, 40).cuda()

    async_mem_monitor.start()
    output = OP1(input)
    async_mem_monitor.finish()
    async_mem_monitor.start()
    output = OP2(output)
    async_mem_monitor.finish()
    async_mem_monitor.save('log.pkl')
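Note that the test uses a same-directory import (from async_memtracer import AsyncMemoryMonitor), so it only runs from inside colossalai/utils/memory_tracer/. With the package installed, the equivalent import goes through the new __init__.py; a sketch:

```python
# Works from anywhere once colossalai is importable,
# via the re-export in memory_tracer/__init__.py.
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
```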