OpenDAS / ColossalAI / Commits

Unverified commit 223332ff, authored Dec 05, 2022 by Jiarui Fang, committed by GitHub on Dec 05, 2022

[Gemini] rename ParamTracerWrapper -> RuntimeMemTracer (#2073)
Parent: 9f828ef3

Showing 2 changed files with 30 additions and 16 deletions (+30 -16):

  colossalai/gemini/memory_tracer/runtime_mem_tracer.py  +16 -9
  tests/test_gemini/test_runtime_mem_tracer.py           +14 -7
colossalai/gemini/memory_tracer/param_tracer_wrapper.py → colossalai/gemini/memory_tracer/runtime_mem_tracer.py  (+16 -9)

 import torch.nn
 
-from colossalai.tensor.param_op_hook import ParamOpHookManager
-from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.ophooks.param_trace_hook import GradHook, ParamTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.tensor.param_op_hook import ParamOpHookManager
 
-__all__ = ['ParamTracerWrapper']
+__all__ = ['RuntimeMemTracer']
 
-class ParamTracerWrapper():
+class RuntimeMemTracer():
 
     def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
         super().__init__()
@@ -25,12 +26,18 @@ class ParamTracerWrapper():
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
-    def _save_param_data_on_cpu(self):
+    def _backup_params(self):
+        """
+        The function is called before forward. Backup model params on cpu.
+        """
         for p in self.module.parameters():
             self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
             self.cpu_param_data_dict[p].copy_(p.data)
 
-    def _restore_param_data(self):
+    def _restore_params(self):
+        """
+        This function is called after backward. Restore model params.
+        """
         for p in self.module.parameters():
             p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
             p.data.copy_(self.cpu_param_data_dict[p])
@@ -38,7 +45,7 @@ class ParamTracerWrapper():
     def _pre_forward(self):
         self._clear_cuda_mem_info()
-        self._save_param_data_on_cpu()
+        self._backup_params()
         self.grad_hook.register_grad_hook()
         self.param_op_hook.mem_monitor.start()
@@ -60,7 +67,7 @@ class ParamTracerWrapper():
         last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
         GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
         self.grad_hook.remove_grad_hook()
-        self._restore_param_data()
+        self._restore_params()
 
     def _clear_cuda_mem_info(self):
         GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
@@ -72,4 +79,4 @@ class ParamTracerWrapper():
         for buffer in self.module.buffers():
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
                 buffer.data = buffer.data.to(self.dtype)
\ No newline at end of file
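
The mechanism the renamed class keeps is the CPU backup/restore of parameter data around a traced iteration: _backup_params snapshots every parameter into a CPU buffer before forward, and _restore_params writes the snapshots back after backward. The following is a minimal standalone sketch of that pattern only; the helper names and toy module are illustrative, not the ColossalAI implementation.

# A minimal standalone sketch of the backup/restore pattern shown in the diff
# above (_backup_params / _restore_params). Helper names and the toy model are
# illustrative only.
import torch
import torch.nn as nn


def backup_params(module: nn.Module, dtype: torch.dtype = torch.half) -> dict:
    # Snapshot every parameter into a CPU buffer of the tracing dtype.
    cpu_copies = {}
    for p in module.parameters():
        buf = torch.empty(p.data.shape, dtype=dtype, device="cpu")
        buf.copy_(p.data)
        cpu_copies[p] = buf
    return cpu_copies


def restore_params(module: nn.Module, cpu_copies: dict, dtype: torch.dtype = torch.half):
    # Re-materialize each parameter on CPU and copy its snapshot back,
    # mirroring what _restore_params does after backward.
    for p in module.parameters():
        p.data = torch.empty(p.data.shape, dtype=dtype, device="cpu")
        p.data.copy_(cpu_copies[p])


if __name__ == "__main__":
    model = nn.Linear(4, 4)
    copies = backup_params(model)
    # ... a traced forward/backward may freely move or cast p.data here ...
    restore_params(model, copies)
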
tests/test_gemini/test_param_tracer.py → tests/test_gemini/test_runtime_mem_tracer.py  (+14 -7)

+from copy import deepcopy
+
 import numpy as np
 import torch
 
-from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
 from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
+from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
 from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half):
     with torch.cuda.amp.autocast(enabled=enable_autocast):
         if criterion:
@@ -16,9 +20,9 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
     loss = loss.to(dtype)
     model.backward(loss)
 
 def run_param_wrapper_testing():
     test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
@@ -26,7 +30,8 @@ def run_param_wrapper_testing():
         with ColoInitContext(device=torch.device('cpu')):
             model = model_builder(checkpoint=False)
 
-        model = ParamTracerWrapper(model)
+        model_bk = deepcopy(model)
+        runtime_mem_tracer = RuntimeMemTracer(model)
 
         for i, (data, label) in enumerate(train_dataloader):
             if i > 1:
@@ -34,15 +39,17 @@ def run_param_wrapper_testing():
             data = data.cuda()
             label = label.cuda()
 
-            run_fwd_bwd(model, data, label, criterion, False)
+            run_fwd_bwd(runtime_mem_tracer, data, label, criterion, False)
+            for p1, p2 in zip(model_bk.parameters(), model.parameters()):
+                torch.allclose(p1.to(torch.half), p2)
 
         cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
         print("cuda_non_model_data_list", len(cuda_non_model_data_list))
         # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
 
         del model
 
 if __name__ == '__main__':
     run_param_wrapper_testing()
\ No newline at end of file
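
For context, a condensed usage sketch of the renamed tracer, assuming only the APIs visible in this diff (RuntimeMemTracer wrapping a model built under ColoInitContext, __call__ running the traced forward, backward() finishing the trace, and GLOBAL_CUDA_MEM_INFO collecting per-step statistics). The toy model and tensor shapes are hypothetical, and a CUDA device is required.

# A condensed usage sketch, assuming the APIs visible in this diff; the toy
# model and shapes are hypothetical and a CUDA device is required.
import torch
import torch.nn as nn

from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from colossalai.utils.model.colo_init_context import ColoInitContext

with ColoInitContext(device=torch.device('cpu')):
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

runtime_mem_tracer = RuntimeMemTracer(model)    # casts params/buffers to torch.half

data = torch.randn(16, 64, device='cuda', dtype=torch.half)
loss = runtime_mem_tracer(data).sum()           # __call__ dispatches to forward()
runtime_mem_tracer.backward(loss)               # restores params after backward

# Per-step non-model CUDA memory recorded by the tracer, in MB
print([v / 1024**2 for v in GLOBAL_CUDA_MEM_INFO.non_model_data_list])
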