OpenDAS / ColossalAI · Commits · a7adad9c
Unverified commit a7adad9c, authored Dec 05, 2022 by Jiarui Fang; committed by GitHub on Dec 05, 2022.
[Gemini] rename hooks related to runtime mem tracer (#2076)
Parent: 40b7d55b

Showing 5 changed files with 15 additions and 111 deletions (+15 −111):
colossalai/gemini/memory_tracer/runtime_mem_tracer.py    +3 −3
colossalai/gemini/ophooks/_shard_grad_ophook.py          +2 −1
colossalai/gemini/ophooks/mem_trace_hook.py              +0 −100 (deleted)
colossalai/gemini/ophooks/runtime_mem_tracer_hook.py     +8 −5 (renamed from param_trace_hook.py)
tests/test_gemini/test_runtime_mem_tracer.py             +2 −2
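Taken together, the renames make each hook's name say what it traces: ParamTracerHook becomes ParamMemTracerHook, GradHook becomes GradMemTracerHook, ShardGradHook becomes ShardGradMemTracerHook, the module param_trace_hook.py becomes runtime_mem_tracer_hook.py, and the superseded mem_trace_hook.py is removed.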
colossalai/gemini/memory_tracer/runtime_mem_tracer.py (view file @ a7adad9c)

 import torch.nn

 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
-from colossalai.gemini.ophooks.param_trace_hook import GradHook, ParamTracerHook
+from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ParamOpHookManager
...
@@ -14,8 +14,8 @@ class RuntimeMemTracer():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook()
-        self.grad_hook = GradHook(module)
+        self.param_op_hook = ParamMemTracerHook()
+        self.grad_hook = GradMemTracerHook(module)
         self.cpu_param_data_dict = {}
         for p in module.parameters():
...
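For context, the test at the bottom of this diff shows how the tracer is constructed. Below is a minimal sketch of the intended flow, assuming only what the diff itself shows (the construction RuntimeMemTracer(model) and the GLOBAL_CUDA_MEM_INFO stats object); the toy model is hypothetical and the training step is elided:

import torch

from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer

# Hypothetical toy model; the test uses a model_builder from its component registry.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

# Wrapping the module installs the renamed ParamMemTracerHook and GradMemTracerHook.
runtime_mem_tracer = RuntimeMemTracer(model)

# ... run one forward/backward iteration through the tracer (driver loop elided) ...

# Per-operator non-model CUDA data volumes (bytes) collected during the iteration.
print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)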
colossalai/gemini/ophooks/_shard_grad_ophook.py (view file @ a7adad9c)

 import torch

 from colossalai.registry import OPHOOKS

 from . import BaseOpHook


 @OPHOOKS.register_module
-class ShardGradHook(BaseOpHook):
+class ShardGradMemTracerHook(BaseOpHook):
     """
     A hook to process sharded params before and after FWD and BWD operators execute.
     """
...
colossalai/gemini/ophooks/mem_trace_hook.py (deleted, 100644 → 0; view file @ 40b7d55b)

import torch

from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.ophooks import BaseOpHook


class MemTracerOpHook(BaseOpHook):
    """
    TODO() what if parameters are sharded by multiple submodules.
    register buff on its father node
    """

    def __init__(self):
        super().__init__()
        self.mem_monitor = SyncCudaMemoryMonitor()
        self._cur_non_model_data_vol = 0
        self._non_model_data_list = []
        self._cur_model_data_vol = 0

    def _move_module_to_dev(self, module, dev: str) -> int:
        """
        move module to target dev
        Args:
            module (torch.nn.Module): a PyTorch module
            dev (torch.device): the target device
        Returns:
            int: the data volume of this module on the cuda
        """
        assert isinstance(dev, str), f"device should be a str not torch.device"
        comm_volume = 0
        for p in module.parameters():
            if p.data.device.type != dev:
                p.data = p.data.to(dev)
                comm_volume += p.data.numel() * p.data.element_size()
            if p.grad is not None:
                if p.grad.device.type != dev:
                    p.grad = p.grad.to(dev)
                    comm_volume += p.grad.numel() * p.grad.element_size()

        for buf in module.buffers():
            if buf.device.type != dev:
                buf.data = buf.data.to(dev)
                comm_volume += buf.data.numel() * buf.data.element_size()

        if dev == 'cuda':
            self._cur_model_data_vol = comm_volume

        return comm_volume

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            self._non_model_data_list.append(cuda_volume - comm_volume)
            # print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # bwd Op will generate grad. comm_volume is grad + data volume on cuda.
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            self._non_model_data_list.append(cuda_volume - comm_volume)
            # print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_iter(self):
        pass

    def post_iter(self):
        self.mem_monitor.finish()
        # print(f'post_iter')

    def print_non_model_data(self):
        print(self._non_model_data_list)

    def save_results(self, filename):
        self.mem_monitor.save(filename)

    def show_mem_stats(self):
        start_timestamp = min(self.mem_monitor.time_stamps)
        self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
        min_mem_used = min(self.mem_monitor.mem_stats)
        self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
        print(self.mem_monitor.time_stamps)
        print(self.mem_monitor.mem_stats)
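As an aside, the accounting in _move_module_to_dev above boils down to numel() * element_size() per tensor, summed over parameters, gradients, and buffers; whatever else the memory monitor observes on the device is counted as non-model data. A self-contained sketch of that bookkeeping (the helper name is hypothetical, and it is CPU-only so it runs anywhere):

import torch


def module_data_volume(module: torch.nn.Module) -> int:
    # Bytes held by parameters, gradients, and buffers: the same
    # numel() * element_size() accounting used in _move_module_to_dev.
    volume = 0
    for p in module.parameters():
        volume += p.data.numel() * p.data.element_size()
        if p.grad is not None:
            volume += p.grad.numel() * p.grad.element_size()
    for buf in module.buffers():
        volume += buf.numel() * buf.element_size()
    return volume


m = torch.nn.Linear(128, 128)    # hypothetical example module
print(module_data_volume(m))     # (128*128 + 128) float32 params -> 66048 bytes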
colossalai/gemini/ophooks/param_trace_hook.py → colossalai/gemini/ophooks/runtime_mem_tracer_hook.py (view file @ a7adad9c)

...
@@ -6,9 +6,9 @@ from typing import List
 import torch

 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
-from colossalai.tensor.param_op_hook import ParamOpHook
-from colossalai.gemini.tensor_utils import free_storage, alloc_storage
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.tensor.param_op_hook import ParamOpHook


 class TrainingPhase(Enum):
...
@@ -16,7 +16,8 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class GradHook():
+class GradMemTracerHook():
+
     def __init__(self, module: torch.nn.Module):
         self.module = module
         self.grad_hook_list = []
...
@@ -38,7 +39,7 @@ class GradHook():
         hook.remove()


-class ParamTracerHook(ParamOpHook):
+class ParamMemTracerHook(ParamOpHook):

     def __init__(self) -> None:
         super().__init__()
...
@@ -57,7 +58,9 @@ class ParamTracerHook(ParamOpHook):
         if cur_dev == "cpu":
             if p.grad is not None and p.grad.device.type == "cpu":
                 raise NotImplementedError("Only run in forward propagation")
-            p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype,
-                                 requires_grad=p.data.requires_grad)
+            p.data = torch.empty(p.data.shape,
+                                 device="cuda",
+                                 dtype=p.data.dtype,
+                                 requires_grad=p.data.requires_grad)
         elif cur_dev == "cuda":
             alloc_storage(p.data)
...
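The alloc_storage/free_storage imports hint at the mechanism behind this hook: a tensor's backing storage can be released and re-grown in place while the tensor object itself (shape, dtype, autograd bookkeeping) survives. A rough sketch of that idea using raw PyTorch storage resizing; this is an assumption about how ColossalAI's helpers work, not their actual code:

import torch

t = torch.empty(1024, 1024)       # ~4 MiB of float32 payload
t.storage().resize_(0)            # drop the payload; tensor metadata survives
assert t.storage().size() == 0    # no elements are backed by memory now
t.storage().resize_(t.numel())    # re-grow the storage before t is touched again
assert t.storage().size() == t.numel()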
tests/test_gemini/test_runtime_mem_tracer.py (view file @ a7adad9c)

...
@@ -29,7 +29,7 @@ def test_runtime_mem_tracer():
     model_builder, train_dataloader, _, _, criterion = get_components_func()

     with ColoInitContext(device=torch.device('cpu')):
-        model = model_builder(checkpoint=True)
+        model = model_builder(checkpoint=False)

     model_bk = deepcopy(model)
     runtime_mem_tracer = RuntimeMemTracer(model)
...
@@ -47,7 +47,7 @@ def test_runtime_mem_tracer():
     cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
     print("cuda_non_model_data_list", len(cuda_non_model_data_list))
-    # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
+    print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)

     del model
...