OpenDAS / ColossalAI
Commit d86ddd9b (unverified)
Authored Aug 11, 2023 by LuGY; committed by GitHub on Aug 11, 2023
[hotfix] fix unsafe async comm in zero (#4404)

* improve stability of zero
* fix wrong index
* add record stream
Parent: 6ccecc0c
Showing 3 changed files with 46 additions and 20 deletions.
colossalai/zero/low_level/bookkeeping/bucket_store.py  (+36, -19)
colossalai/zero/low_level/low_level_optim.py           (+9, -0)
tests/test_zero/test_low_level/test_zero1_2.py         (+1, -1)
colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -13,15 +13,20 @@ class BucketStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
 
-        # init and reset
+        # init
         self.current_group_id = 0
+        self._num_elements_in_bucket = 0
         # mapping gardient slices and parameter
         self.grad_to_param_mapping = dict()
 
+        self._grad_in_bucket = dict()
         self._param_list = []
         self._padding_size = []
+        for rank in range(self._world_size):
+            self._grad_in_bucket[rank] = []
 
-        self.reset()
+        # offset_list records number of tensors in the bucket before each reduction
+        self.offset_list = [0]
 
     def num_elements_in_bucket(self) -> int:
         """Return the total number of elements in bucket
@@ -32,6 +37,12 @@ class BucketStore(BaseStore):
         """
         return self._num_elements_in_bucket
 
+    def reset_num_elements_in_bucket(self):
+        """Set the number of elements in bucket to zero.
+        """
+        self._num_elements_in_bucket = 0
+
     def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
         """Add a param to bucket and record the padding size of a param for gradient padding
@@ -46,28 +57,32 @@ class BucketStore(BaseStore):
         self._num_elements_in_bucket += (param.numel() + padding_size)
         self.current_group_id = group_id
 
+        # number of tensors in current bucket
+        self.offset_list[-1] += 1
+
     def build_grad_in_bucket(self):
         """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method
 
         Data structure of self._grad_in_bucket:
         {
             rank0: [grad0_rank0, grad1_rank0, ...]
-            rank1: [grad1_rank1, grad1_rank1, ...]
+            rank1: [grad0_rank1, grad1_rank1, ...]
         }
         """
         for param, padding_size in zip(self._param_list, self._padding_size):
-            with torch.no_grad():
-                grad = param.grad.clone().detach().flatten()
-                if padding_size > 0:
-                    grad = torch.nn.functional.pad(grad, [0, padding_size])
-                grad_list = grad.split(grad.numel() // self._world_size)
+            grad = param.grad.detach().flatten()
+            if padding_size > 0:
+                with torch.no_grad():
+                    grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
+            grad_list = grad.split(grad.numel() // self._world_size)
             for rank in range(self._world_size):
-                grad_current_rank = grad_list[rank].detach()
+                grad_current_rank = grad_list[rank].clone().detach()
                 self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
                 self._grad_in_bucket[rank].append(grad_current_rank)
             param.grad = None
 
+        self.offset_list.append(0)
+
     def get_grad(self) -> Dict:
         """Return the dictionary of gradients slices, of which the keys are ranks
@@ -104,10 +119,12 @@ class BucketStore(BaseStore):
         return self.grad_to_param_mapping[id(grad)]
 
     def reset(self):
-        self.grad_to_param_mapping = dict()
-        self._num_elements_in_bucket = 0
-        self._param_list = []
-        self._padding_size = []
-        self._grad_in_bucket = dict()
+        """Reset the bucket storage after reduction, only release the tensors have been reduced
+        """
+        cur_offset = self.offset_list.pop(0)
+        self._param_list = self._param_list[cur_offset:]
+        self._padding_size = self._padding_size[cur_offset:]
+        for _ in range(cur_offset):
+            del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))]
         for rank in range(self._world_size):
-            self._grad_in_bucket[rank] = []
+            self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:]
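The reworked reset() no longer wipes the whole store. offset_list records how many gradient tensors entered the bucket before each reduction, and reset() releases only that many, so gradients queued while an earlier, possibly asynchronous, reduction is still in flight are preserved. The diff also clones each rank's slice out of the flattened gradient, so the slices own their storage once param.grad is set to None. A minimal sketch of the bookkeeping idea, using a hypothetical TinyBucket stand-in rather than the real BucketStore API:

# Simplified stand-in for the offset_list bookkeeping (hypothetical class,
# not the actual BucketStore): offset_list[-1] counts tensors added since
# the last reduction started; reset() pops the oldest count and releases
# only those entries, keeping anything queued behind an in-flight reduction.
class TinyBucket:
    def __init__(self):
        self.params = []
        self.offset_list = [0]

    def add(self, p):
        self.params.append(p)
        self.offset_list[-1] += 1

    def start_reduction(self):
        # everything added so far belongs to the reduction that just launched
        self.offset_list.append(0)

    def reset(self):
        cur_offset = self.offset_list.pop(0)
        self.params = self.params[cur_offset:]

bucket = TinyBucket()
bucket.add("grad0")
bucket.add("grad1")
bucket.start_reduction()   # async reduction of grad0/grad1 begins
bucket.add("grad2")        # arrives while the reduction is still in flight
bucket.reset()             # releases only grad0/grad1
assert bucket.params == ["grad2"]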
colossalai/zero/low_level/low_level_optim.py
@@ -242,10 +242,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _run_reduction(self):
         if self._bucket_store.num_elements_in_bucket() > 0:
             self._bucket_store.build_grad_in_bucket()
+
             flat_grads = self._bucket_store.get_flatten_grad()
             flat_grads /= self._world_size
+
+            # ready to add other tensors to bucket
+            self._bucket_store.reset_num_elements_in_bucket()
+
             if self._overlap_communication:
                 stream = self._comm_stream
+                # in case of the memory being reused in the default stream
+                flat_grads.record_stream(stream)
+                # waiting for ops in the default stream finishing
+                stream.wait_stream(torch.cuda.current_stream())
             else:
                 stream = torch.cuda.current_stream()
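The two stream calls are the heart of the "unsafe async comm" fix. Without wait_stream, the communication stream could read flat_grads before the default stream had finished producing it; without record_stream, PyTorch's caching allocator could hand flat_grads's memory back to the default stream while the async collective was still using it. A minimal sketch of the pattern with standard PyTorch APIs (tensor and stream names are illustrative, and a CUDA device is assumed):

import torch

# hypothetical names; only record_stream/wait_stream mirror the actual fix
comm_stream = torch.cuda.Stream()
flat_grads = torch.randn(1 << 20, device="cuda")

# tell the caching allocator this tensor is in use on comm_stream, so its
# memory is not recycled by the default stream while the async op runs
flat_grads.record_stream(comm_stream)

# block comm_stream until the default stream has finished writing flat_grads
comm_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(comm_stream):
    flat_grads.mul_(0.5)   # stand-in for an async all-reduce on flat_grads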
tests/test_zero/test_low_level/test_zero1_2.py
@@ -137,7 +137,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype):
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=True,
                                            initial_scale=1,
-                                           reduce_bucket_size=262144)
+                                           reduce_bucket_size=1024 * 1024)
 
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)