OpenDAS / ColossalAI · Commits

Unverified commit 7066dfbf, authored Nov 16, 2022 by HELSON, committed by GitHub on Nov 16, 2022
[zero] fix memory leak for zero2 (#1955)
Parent: 60abd86d

Showing 2 changed files with 171 additions and 9 deletions:

    colossalai/zero/sharded_optim/low_level_optim.py     +10 −9
    tests/test_zero/low_level_zero/test_grad_clip.py     +161 −0
colossalai/zero/sharded_optim/low_level_optim.py
@@ -48,7 +48,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  verbose=False,

                  # communication
-                 reduce_bucket_size=500000000,
+                 reduce_bucket_size=50000000,
                  communication_dtype=torch.float16,
                  overlap_communication=False,
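A back-of-the-envelope check (not from the commit, and assuming the bucket size is counted in elements, as in comparable ZeRO implementations): at fp16, the new default shrinks the reduction buffer by roughly an order of magnitude.

    # Hypothetical sizing check: 2 bytes per fp16 element.
    old_bucket_gib = 500_000_000 * 2 / 1024**3    # ~0.93 GiB
    new_bucket_gib = 50_000_000 * 2 / 1024**3     # ~0.09 GiB
    print(f"{old_bucket_gib:.2f} GiB -> {new_bucket_gib:.2f} GiB")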
@@ -125,14 +125,14 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # partition these param groups for data parallel training
         # and add buffers to parameter store for future access
         for group_id, param_group in enumerate(self._optimizer.param_groups):
-            params = param_group['params']
+            group_params = param_group['params']

             # add the fp16 params to fp16_param_groups for bookkeeping
-            self._fp16_param_groups[group_id] = params
+            self._fp16_param_groups[group_id] = group_params

             # assign parameters to ranks
             # the params in the list are sorted
-            params_per_rank = self._partition_param_list(params)
+            params_per_rank = self._partition_param_list(group_params)

             # store the mapping between param to rank
             # each param should belong to only one rank
@@ -143,14 +143,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             # move to cpu to make room to create the flat tensor
             # move_tensor(params, device='cpu')
-            for param in params:
+            for param in group_params:
                 param.data = param.data.cpu()

             # flatten the reordered tensors
             for rank in range(self._world_size):
                 tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
-                flat_tensor = flatten(tensor_list)
-                flat_tensor = flat_tensor.cuda()
+                with torch.no_grad():
+                    flat_tensor = flatten(tensor_list)
+                flat_tensor = flat_tensor.data.cuda()
                 self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)

             # sync parameters
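This hunk is the heart of the leak fix. Flattening concatenates the fp16 parameters, and since parameters require grad, doing so while autograd is recording gives the flat tensor a grad_fn that references every source parameter; storing that tensor in the parameter store then pins the whole graph in memory. A minimal sketch of the PyTorch behavior involved, using torch.cat directly (the names here are illustrative, not from the commit):

    import torch

    params = [torch.randn(1024, requires_grad=True) for _ in range(4)]

    # Flattening while autograd records: torch.cat creates a CatBackward
    # node, so the flat tensor keeps a grad_fn referencing every parameter.
    flat_tracked = torch.cat([p.view(-1) for p in params])
    print(flat_tracked.grad_fn)    # <CatBackward0 ...>  -> graph stays alive

    # Flattening under no_grad(): the result is a plain tensor with no
    # autograd history, so storing it long-term does not pin a graph.
    with torch.no_grad():
        flat_plain = torch.cat([p.view(-1) for p in params])
    print(flat_plain.grad_fn)      # None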
@@ -161,7 +162,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
-            fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach()
+            fp32_flat_current_rank = fp16_flat_current_rank.float()
             device = 'cpu' if self._cpu_offload else get_current_device()
             fp32_flat_current_rank = fp32_flat_current_rank.to(device)
             fp32_flat_current_rank.requires_grad = True
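The old line allocated three tensors in a row (clone, float, detach). A dtype-changing .float() always allocates a fresh fp32 tensor of its own, and since the flat fp16 tensor now carries no autograd history after the no_grad change above, the extra clone() and detach() were redundant copies. A quick check of the copy semantics (illustrative only, not from the commit):

    import torch

    t = torch.randn(4, dtype=torch.float16)
    u = t.float()    # dtype conversion always allocates a new fp32 tensor
    assert u.data_ptr() != t.data_ptr()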
@@ -384,7 +385,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # torch.optim.Optimizer methods
     ################################

-    def backward(self, loss, retain_graph=True):
+    def backward(self, loss, retain_graph=False):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)
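With retain_graph=True as the default, every backward call kept that step's autograd graph, and the activations saved inside it, alive until the loss tensor itself was garbage-collected, which shows up as steadily growing memory across iterations. A minimal illustration of the PyTorch semantics (not from the commit):

    import torch

    x = torch.randn(512, 512, requires_grad=True)

    loss = (x @ x).sum()
    loss.backward(retain_graph=True)    # saved activations survive this call
    loss.backward()                     # legal only because the graph was retained

    loss2 = (x @ x).sum()
    loss2.backward()                    # default: saved tensors freed immediately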
tests/test_zero/low_level_zero/test_grad_clip.py
new file (mode 100644)
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer


def check_equal(a, b, rtol=1e-4, atol=1e-3):
    """
    This function checks if two tensors are equal within tolerance
    """
    assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}'


def check_completely_equal(a, b):
    """
    This function checks if two tensors are completely equal
    """
    assert torch.all(a == b), f'a = {a}, b = {b}'


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def exam_zero_1_2_grad_clip():
    # create model
    zero1_model = TestModel().cuda().half()
    zero2_model = copy.deepcopy(zero1_model)

    # create optimizer
    zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001)
    zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001)
    zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0)

    # create
    input_data = torch.rand(32, 128).cuda().half()

    # forward
    zero1_output = zero1_model(input_data)
    zero2_output = zero2_model(input_data)
    check_completely_equal(zero1_output, zero2_output)

    # backward
    zero1_optimizer.backward(zero1_output.mean().float())
    zero2_optimizer.backward(zero2_output.mean().float())

    # check grad
    # as this param is small, the backward reduction
    # will not be fired
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        check_completely_equal(z1p.grad, z2p.grad)

    # step
    zero1_optimizer.sync_grad()
    zero2_optimizer.sync_grad()

    # step
    zero1_optimizer.step()
    zero2_optimizer.step()

    # check updated param
    for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
        check_completely_equal(z1p.data, z2p.data)


def exam_zero_1_grad_clip():
    # create models
    zero_model = TestModel()
    torch_model = copy.deepcopy(zero_model)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda())

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)

    # we only test stage 1 here
    # in `check_sharded_param_consistency.py`, we will test whether
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           clip_grad_norm=1.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

    # create
    input_data = torch.rand(32, 128).cuda()

    # zero-dp forward
    zero_output = zero_model(input_data.half())

    # torch-ddp forward
    torch_output = torch_model(input_data)
    check_equal(zero_output, torch_output)

    # zero-dp backward
    zero_optimizer.backward(zero_output.mean().float())

    # torch-ddp backward
    torch_output.mean().backward()

    # check grad
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        check_equal(p.grad, z1p.grad)

    # zero-dp step
    zero_optimizer.sync_grad()
    zero_optimizer.step()

    # torch ddp step
    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
    torch_optimizer.step()

    # check updated param
    for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
        check_equal(p.data, z1p.data, atol=5e-4)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_2_grad_clip()
    exam_zero_1_grad_clip()


@pytest.mark.dist
def test_grad_clip():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_grad_clip()
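To exercise the new test on its own, running `pytest tests/test_zero/low_level_zero/test_grad_clip.py` or simply `python test_grad_clip.py` should work on a machine with at least 2 GPUs, since the test spawns `world_size = 2` processes; this is an inference from the file itself, not something documented in the commit.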