OpenDAS / ColossalAI · Commits

Commit 6a9158f1 (unverified)
[Gemini] free and allocate cuda memory by tensor.storage, add grad hook (#2040)
Authored Nov 30, 2022 by Zihao; committed by GitHub on Nov 30, 2022
Parent: 1e885329

Showing 4 changed files with 40 additions and 18 deletions
colossalai/gemini/memory_tracer/param_tracer_wrapper.py    +9   -3
colossalai/gemini/ophooks/param_trace_hook.py               +16  -14
colossalai/gemini/tensor_utils.py                           +14  -0
tests/test_gemini/test_param_tracer.py                      +1   -1
colossalai/gemini/memory_tracer/param_tracer_wrapper.py (+9, -3)

import torch.nn

from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
+from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float
+from functools import partial

__all__ = ['ParamTracerWrapper']

@@ -13,17 +15,21 @@ class ParamTracerWrapper():
        super().__init__()
        self.module = module
        self.dtype = dtype
-       self.param_op_hook = ParamTracerHook()
+       self.param_op_hook = ParamTracerHook(dtype)

        for p in module.parameters():
            assert isinstance(p, ColoParameter)
            p.data = p.data.to(dtype)
+           if p.requires_grad:
+               p.register_hook(partial(self.grad_handle))

        self._cast_buffers_to_cuda_dtype()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

+   def grad_handle(self, grad):
+       free_storage(grad)
+
    def _pre_forward(self):
        self.param_op_hook.mem_monitor.start()
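A note on the new grad hook: torch.Tensor.register_hook fires its callable as soon as autograd produces the gradient for that tensor, and ParamTracerWrapper.grad_handle uses that moment to call free_storage(grad), so a traced run never keeps gradient buffers alive. Below is a minimal standalone sketch of the mechanism (not the commit's code; this hook only measures the gradient instead of freeing it):

import torch

observed_bytes = []

p = torch.randn(4, 4, requires_grad=True)
# The hook runs right after the gradient w.r.t. p is computed during backward().
p.register_hook(lambda grad: observed_bytes.append(grad.numel() * grad.element_size()))

(p * 2.0).sum().backward()
print(observed_bytes)    # [64] -- one 4x4 float32 gradient was seen by the hook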
colossalai/gemini/ophooks/param_trace_hook.py (+16, -14)

@@ -7,6 +7,7 @@ import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.tensor.param_op_hook import ParamOpHook
+from colossalai.gemini.tensor_utils import free_storage, alloc_storage


class TrainingPhase(Enum):

@@ -16,25 +17,26 @@ class TrainingPhase(Enum):
class ParamTracerHook(ParamOpHook):

-   def __init__(self) -> None:
+   def __init__(self, dtype: torch.dtype = torch.half) -> None:
        super().__init__()
        self._training_phase = TrainingPhase.FORWARD
        self.mem_monitor = SyncCudaMemoryMonitor()
        self._non_model_data_list = []
        self._model_data_list = []
+       self.dtype = dtype

-   def _move_params_to_dev(self, params, dev: str) -> int:
-       assert isinstance(dev, str), f"device should be a str not torch.device"
-       comm_volume = 0
-       for p in params:
-           if p.data.device.type != dev:
-               p.data = p.data.to(dev)
-               comm_volume += p.data.numel() * p.data.element_size()
-           if p.grad is not None:
-               if p.grad.device.type != dev:
-                   p.grad = p.grad.to(dev)
-                   comm_volume += p.grad.numel() * p.grad.element_size()
-       return comm_volume
+   def _free_cuda_params(self, params):
+       for p in params:
+           free_storage(p.data)
+
+   def _allocate_params_on_cuda(self, params):
+       for p in params:
+           cur_dev = p.data.device.type
+           if cur_dev == "cpu":
+               # p.data = p.data.to("cuda")
+               p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+           elif cur_dev == "cuda":
+               alloc_storage(p.data)

    def sample_model_data(self, params):
        data_volume = 0

@@ -49,12 +51,12 @@ class ParamTracerHook(ParamOpHook):
        cuda_volume = self.mem_monitor.finish()
        if len(self._model_data_list):
            self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
-       self._move_params_to_dev(params, 'cuda')
+       self._allocate_params_on_cuda(params)
        self.sample_model_data(params)
        self.mem_monitor.start()

    def post_op(self, params):
-       self._move_params_to_dev(params, 'cpu')
+       self._free_cuda_params(params)

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)
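To make the pre_op/post_op change concrete, here is a hypothetical, greatly simplified driver (it does not use ColossalAI's ParamOpHookManager, and the names below are illustrative only) showing the intended lifecycle: a parameter owns real device memory only around the operator that touches it, is refilled with fresh random data when it came from CPU (the tracer cares about memory footprint, not values), and has its storage released again right after the op.

import torch

def free_storage(t: torch.Tensor) -> None:      # same idea as tensor_utils.free_storage
    if t.storage().size() > 0:
        t.storage().resize_(0)

def alloc_storage(t: torch.Tensor) -> None:     # same idea as tensor_utils.alloc_storage
    if t.storage().size() == 0:
        t.storage().resize_(t.numel())

device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(8, 8, device=device)
free_storage(weight)                            # steady state: metadata only, no live buffer

# pre_op: give the parameter real storage just before the traced operator runs
alloc_storage(weight)
x = torch.randn(2, 8, device=device)
y = x @ weight                                  # the operator executes with real memory

# post_op: drop the parameter's buffer again once the operator is done
free_storage(weight)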
colossalai/gemini/tensor_utils.py (+14, -0)

@@ -3,6 +3,20 @@ from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple


+def is_storage_empty(tensor: torch.Tensor) -> bool:
+    return tensor.storage().size() == 0
+
+
+def free_storage(tensor: torch.Tensor) -> None:
+    if not is_storage_empty(tensor):
+        tensor.storage().resize_(0)
+
+
+def alloc_storage(tensor: torch.Tensor) -> None:
+    if is_storage_empty(tensor):
+        tensor.storage().resize_(tensor.numel())
+
+
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
    if isinstance(tensor, StatefulTensor):
        t = tensor.payload
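A short standalone usage sketch of the three new helpers (not part of the diff). The key property: resizing a tensor's storage releases or re-acquires the underlying buffer (for CUDA tensors, via the caching allocator) while the tensor object and its shape metadata stay intact; re-allocation does not restore the old values, which is why the tracer recreates CPU-resident parameters with torch.randn instead of copying them back.

import torch
from colossalai.gemini.tensor_utils import is_storage_empty, free_storage, alloc_storage

t = torch.ones(4, 4)                      # works the same way for CUDA tensors
free_storage(t)
print(t.shape, is_storage_empty(t))       # torch.Size([4, 4]) True -- metadata kept, buffer gone
alloc_storage(t)
print(t.storage().size())                 # 16 -- buffer is back, but its contents are uninitialized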
tests/test_gemini/test_param_tracer.py (+1, -1)

@@ -16,7 +16,7 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
    model.backward(loss)


def run_param_wrapper_testing():
-   test_models = ['simple_net']
+   test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
    for model_name in test_models:
        get_components_func = non_distributed_component_funcs.get_callable(model_name)