Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
31922110
Unverified
Commit
31922110
authored
Nov 18, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 18, 2022
Browse files
[Gemini] memory trace hook (#1978)
parent
0529fcde
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
2 deletions
+10
-2
colossalai/gemini/ophooks/mem_trace_hook.py
colossalai/gemini/ophooks/mem_trace_hook.py
+10
-2
No files found.
colossalai/gemini/ophooks/mem_trace_hook.py
View file @
31922110
...
@@ -5,6 +5,9 @@ from colossalai.gemini.ophooks import BaseOpHook
...
@@ -5,6 +5,9 @@ from colossalai.gemini.ophooks import BaseOpHook
class
MemTracerOpHook
(
BaseOpHook
):
class
MemTracerOpHook
(
BaseOpHook
):
"""
TODO() what if parameters are sharded by multiple submodules.
"""
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
...
@@ -14,8 +17,8 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -14,8 +17,8 @@ class MemTracerOpHook(BaseOpHook):
self
.
_cur_model_data_vol
=
0
self
.
_cur_model_data_vol
=
0
def
_move_module_to_dev
(
self
,
module
,
dev
:
str
)
->
int
:
def
_move_module_to_dev
(
self
,
module
,
dev
:
str
)
->
int
:
"""
_move_module_to_dev
"""
move module to
cuda
move module to
target dev
Args:
Args:
module (torch.nn.Module): a PyTorch module
module (torch.nn.Module): a PyTorch module
dev (torch.device): the target device
dev (torch.device): the target device
...
@@ -49,6 +52,7 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -49,6 +52,7 @@ class MemTracerOpHook(BaseOpHook):
if
module
.
training
:
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
cuda_volume
=
self
.
mem_monitor
.
finish
()
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
self
.
_non_model_data_list
.
append
(
cuda_volume
-
comm_volume
)
# print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
# print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
...
@@ -65,6 +69,7 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -65,6 +69,7 @@ class MemTracerOpHook(BaseOpHook):
if
module
.
training
:
if
module
.
training
:
cuda_volume
=
self
.
mem_monitor
.
finish
()
cuda_volume
=
self
.
mem_monitor
.
finish
()
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
comm_volume
=
self
.
_move_module_to_dev
(
module
,
'cpu'
)
self
.
_non_model_data_list
.
append
(
cuda_volume
-
comm_volume
)
# print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
# print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
def
pre_iter
(
self
):
def
pre_iter
(
self
):
...
@@ -74,6 +79,9 @@ class MemTracerOpHook(BaseOpHook):
...
@@ -74,6 +79,9 @@ class MemTracerOpHook(BaseOpHook):
self
.
mem_monitor
.
finish
()
self
.
mem_monitor
.
finish
()
# print(f'post_iter')
# print(f'post_iter')
def
print_non_model_data
(
self
):
print
(
self
.
_non_model_data_list
)
def
save_results
(
self
,
filename
):
def
save_results
(
self
,
filename
):
self
.
mem_monitor
.
save
(
filename
)
self
.
mem_monitor
.
save
(
filename
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment