OpenDAS / ColossalAI · Commits · 4c4388c4

Commit 4c4388c4 (unverified)
Authored Apr 18, 2022 by HELSON; committed by GitHub on Apr 18, 2022
Parent: 4b01da24

[hotfix] fix memory leak in zero (#781)
Showing 6 changed files with 32 additions and 36 deletions (+32 -36):

    colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py     +2   -1
    colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py  +5   -4
    colossalai/zero/sharded_model/sharded_model_v2.py            +2   -2
    colossalai/zero/sharded_optim/sharded_optim_v2.py            +21  -12
    colossalai/zero/sharded_param/sharded_param.py               +1   -6
    tests/test_zero/test_stateful_tensor_mgr.py                  +1   -11
colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py  (+2 -1)

@@ -12,7 +12,7 @@ __all__ = ['BaseGradScaler']
 class BaseGradScaler(ABC):

-    def __init__(self, initial_scale: int, verbose: bool):
+    def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
         self._scale = torch.cuda.FloatTensor([initial_scale])
         self._verbose = verbose

@@ -31,6 +31,7 @@ class BaseGradScaler(ABC):
     def state_dict(self) -> Dict:
         state_dict = dict()
         state_dict['scale'] = self.scale
+        return state_dict

     def load_state_dict(self, state_dict: Dict) -> None:
         self._scale = state_dict['scale']
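The second hunk is the functional fix in this file: state_dict() built the dict but never returned it, so any caller that saved and restored the scaler received None. The int-to-float annotation change matches how scales are actually used (backoff_factor = 0.5 can make the scale fractional). A minimal stand-alone sketch of the intended round-trip, using a CPU tensor instead of torch.cuda.FloatTensor and a hypothetical class name:

    import torch
    from typing import Dict

    class TinyScaler:
        # Stand-in for BaseGradScaler; CPU-only so it runs anywhere.
        def __init__(self, initial_scale: float):
            assert initial_scale > 0
            self._scale = torch.FloatTensor([initial_scale])

        def state_dict(self) -> Dict:
            state_dict = dict()
            state_dict['scale'] = self._scale
            return state_dict  # without this line the caller receives None

        def load_state_dict(self, state_dict: Dict) -> None:
            self._scale = state_dict['scale']

    src, dst = TinyScaler(2**16), TinyScaler(1.0)
    dst.load_state_dict(src.state_dict())  # TypeError here if state_dict() returned None
    assert float(dst._scale) == 2**16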
colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py  (+5 -4)

@@ -3,6 +3,7 @@
 import torch

 from .base_grad_scaler import BaseGradScaler
+from typing import Optional

 __all__ = ['DynamicGradScaler']

@@ -10,12 +11,12 @@ __all__ = ['DynamicGradScaler']
 class DynamicGradScaler(BaseGradScaler):

     def __init__(self,
-                 initial_scale: int = 2**16,
-                 growth_factor: int = 2,
+                 initial_scale: float = 2**16,
+                 growth_factor: float = 2,
                  backoff_factor: float = 0.5,
                  growth_interval: int = 1000,
-                 min_scale: int = None,
-                 max_scale: int = None,
+                 min_scale: Optional[float] = None,
+                 max_scale: Optional[float] = None,
                  hysteresis: int = 2,
                  verbose: bool = False):
         super().__init__(initial_scale, verbose)
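min_scale: int = None was doubly wrong: the default is not an int, and the bounds are compared against a float scale, so Optional[float] now matches actual usage. A usage sketch with illustrative hyper-parameter values (only the keyword names come from the diff):

    from colossalai.amp.naive_amp.grad_scaler.dynamic_grad_scaler import DynamicGradScaler

    # Illustrative values; any positive floats are valid scale bounds.
    # Note: BaseGradScaler stores the scale in a torch.cuda.FloatTensor,
    # so constructing this requires a CUDA device.
    scaler = DynamicGradScaler(initial_scale=2**15,
                               growth_factor=2.0,
                               backoff_factor=0.5,
                               growth_interval=1000,
                               min_scale=1.0,
                               max_scale=2.0**24,
                               hysteresis=2,
                               verbose=True)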
colossalai/zero/sharded_model/sharded_model_v2.py  (+2 -2)

@@ -358,8 +358,8 @@ class ShardedModelV2(nn.Module):
             assert param.colo_attr.saved_grad.is_null(
             ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-            param.colo_attr.reset_grad_payload(grad)
-            param.colo_attr.reset_data_payload(grad)    # release the memory of param
+            param.colo_attr.reset_grad_payload(grad.data)
+            param.colo_attr.reset_data_payload(grad.data)    # release the memory of param

             if param.colo_attr.is_replicated:
                 param.colo_attr.sharded_data_tensor.is_sharded = True
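This is the leak fix on the model side: storing grad itself keeps the original Tensor object, and anything reachable from it, alive inside the parameter's payload, while grad.data is a detached view over the same storage, so the payload retains only the raw buffer. A minimal illustration of the distinction in plain PyTorch (not ColossalAI code):

    import torch

    x = torch.ones(3, requires_grad=True)
    (x * 2).sum().backward()

    g = x.grad
    view = g.data                              # new Tensor over the same storage
    assert view.data_ptr() == g.data_ptr()     # shares memory with the gradient
    assert view is not g                       # but is a distinct Python object
    assert not view.requires_grad              # and carries no autograd tracking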
colossalai/zero/sharded_optim/sharded_optim_v2.py  (+21 -12)

@@ -83,11 +83,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  min_scale: float = 1,
                  growth_factor: float = 2,
                  backoff_factor: float = 0.5,
-                 growth_interval: float = 1000,
-                 hysteresis: float = 2,
-                 max_scale: int = 2**32,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 verbose: bool = False) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
         super().__init__(optimizer)

@@ -115,14 +116,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                                   max_scale=max_scale)
         self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
+        self._verbose = verbose

         # Store fp32 param shards
         self._register_master_weight()

         if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
                                                                AutoTensorPlacementPolicy):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')

-        self._logger.debug(
-            f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
+        if self._verbose:
+            self._logger.debug(
+                f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])

         self._use_memory_tracer = self.model.use_memory_tracer
         if self._use_memory_tracer:

@@ -193,15 +197,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._point_param_fp16_to_master_param()

-        self._logger.debug(
-            f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
-            ranks=[0])
+        if self._verbose:
+            gpu_mem, cpu_mem = self.get_memory_usage()
+            self._logger.debug(
+                f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
+                ranks=[0])

         ret = self.optim.step(*args, **kwargs)

-        self._logger.debug(
-            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
-            ranks=[0])
+        if self._verbose:
+            gpu_mem, cpu_mem = self.get_memory_usage()
+            self._logger.debug(
+                f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
+                ranks=[0])

         self._copy_master_model_to_model_fp16()
         return ret
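Beyond the annotation fixes, two things change here: the memory report is now opt-in via the new verbose flag, and get_memory_usage() is queried once per report rather than once per interpolated value. The pattern generalizes; a stand-alone sketch with hypothetical helper names:

    import logging

    logger = logging.getLogger("memory_report_sketch")

    def report_memory(get_memory_usage, verbose: bool, tag: str) -> None:
        # Skip the memory query entirely when reporting is off, and call it
        # once instead of once per formatted field.
        if verbose:
            gpu_mem, cpu_mem = get_memory_usage()
            logger.debug("%s: %.1f MB GPU, %.1f MB CPU", tag, gpu_mem / 1e6, cpu_mem / 1e6)

    # Example with a stubbed usage function returning (gpu_bytes, cpu_bytes):
    report_memory(lambda: (2.5e8, 1.0e9), verbose=True, tag="After step")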
colossalai/zero/sharded_param/sharded_param.py  (+1 -6)

@@ -5,18 +5,13 @@ from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
 from typing import List

-# use this tensor as empty data point for parameters
-# we do not want users use param.data when its torch payload is removed
-# empty tensor is expected to raise error when get used
-FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
-
 EMPTY_TENSOR_DICT = {}


 def get_empty_tensor(device: torch.device, dtype: torch.dtype):
     key = (device, dtype)
     if key not in EMPTY_TENSOR_DICT:
-        EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
+        EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device)

     return EMPTY_TENSOR_DICT[key]
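The cached placeholders keep the same role: one zero-element tensor per (device, dtype) pair that stands in for a parameter's evicted payload. Allocating it directly with torch.empty(0, dtype=dtype, device=device) avoids routing every placeholder through a shared CPU BoolTensor. A quick check of the placeholder behavior in plain PyTorch:

    import torch

    # A zero-element tensor occupies no element storage and errors on use,
    # which is exactly what a "payload removed" placeholder should do.
    placeholder = torch.empty(0, dtype=torch.float16, device='cpu')
    assert placeholder.numel() == 0
    try:
        placeholder[0]  # any real access fails fast
    except IndexError:
        pass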
tests/test_zero/test_stateful_tensor_mgr.py  (+1 -11)

@@ -72,23 +72,13 @@ def run_stm():
     # warmup done
     # only 2 params can be on CUDA
-    limit_cuda_memory(0.26)
+    limit_cuda_memory(0.26 / tensor_placement_policy._steady_cuda_cap_ratio)
     # use OPT-like eviction strategy
     apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_model_data()
-    mem_collector.finish_collection()


 def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
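The limit_cuda_memory change appears to compensate for the placement policy reserving headroom: if the policy only fills _steady_cuda_cap_ratio of the configured limit in steady state, the raw limit must be scaled up so that the effective budget the eviction assertions rely on stays at 0.26. In numbers (the ratio value below is hypothetical; the test reads the real one from tensor_placement_policy):

    # Assumed ratio for illustration only.
    steady_cuda_cap_ratio = 0.9
    configured_limit = 0.26 / steady_cuda_cap_ratio
    effective_budget = configured_limit * steady_cuda_cap_ratio
    assert abs(effective_budget - 0.26) < 1e-12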