OpenDAS / deepspeed · Commit 55ed1057 (unverified)
Authored Sep 15, 2020 by Jeff Rasley; committed by GitHub on Sep 15, 2020
fix bug related to stitching reduced grads across communication partitions (#318)
Parent: 91b4a93d
Showing 1 changed file with 13 additions and 21 deletions.
deepspeed/runtime/zero/stage1.py (view file @ 55ed1057)
@@ -249,8 +249,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             # RS: divide up the sub-partitions and keep track of offsets for each param
             # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
-            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
-                params_not_local = self.get_all_sub_partition_info(
+            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
                 tensor_list=self.fp16_groups[i],
                 all_element_intervals=element_intervals,
                 local_rank=local_rank,
@@ -591,28 +590,20 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 all_comm_partitions.append(single_comm_all_partitions)
 
-            for p in my_params:
-                partitions = param_partition_map[p]
-                parts = []
-                for part in partitions:
-                    params, offsets = partition_param_map[part]
-                    found = False
-                    for p_idx, _p in enumerate(params):
-                        if p.__hash__() == _p.__hash__():
-                            found = True
-                            if offsets[p_idx][0] is not None:
-                                my_part = part.narrow(0,
-                                                      offsets[p_idx][0],
-                                                      offsets[p_idx][1])
-                                parts.append(my_part)
-                    assert found
-                if p is not None:
-                    updated_grad = _unflatten_dense_tensors(torch.cat(parts),
-                                                            [p])
-                    p.grad.copy_(updated_grad[0])
+            # stitch together all rank sub partitions for each comm idx
+            flat_comm_grads = []
+            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
+                flat_comm_grads.append(torch.cat(rank_partitions))
+            flat_all_grads = torch.cat(flat_comm_grads)
+
+            # copy back reduced gradients but only those needed for this local rank
+            for param, updated_grad in zip(self.fp16_groups[i],
+                    _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
+                if param in my_params:
+                    param.grad.copy_(updated_grad)
 
     def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
         self.overflow = self.overflow_checker.check()
         prev_scale = self.loss_scale
@@ -649,6 +640,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             #)
             #TODO RS: can we safely use dtype of the first sub-partition? i think so
+            # create flat gradient partitions for parameters updated by this process
             local_grad_sub_partitions = self.get_flat_sub_partitions(
                 comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
                 comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
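The replacement code relies on PyTorch's flatten/unflatten helpers (`_flatten_dense_tensors` / `_unflatten_dense_tensors` live in `torch._utils`): the reduced chunks produced by each communication step are concatenated back into a single flat buffer in parameter order, `_unflatten_dense_tensors` reinterprets that buffer with the original parameter shapes, and the resulting views are copied into `param.grad`. Below is a minimal, self-contained sketch of that stitching pattern only; the parameter sizes and the two-way split standing in for the real communication partitions are invented for illustration and are not taken from the commit.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for a parameter group: two parameters with 12 + 4 = 16 elements.
params = [torch.nn.Parameter(torch.randn(3, 4)), torch.nn.Parameter(torch.randn(4))]
for p in params:
    p.grad = torch.zeros_like(p)

# Pretend these are the already-reduced gradients, flattened in parameter order.
reduced_flat = _flatten_dense_tensors([torch.ones_like(p) for p in params])

# Fake "communication partitions": the flat buffer split into two equal chunks,
# as if each chunk came back from a separate reduce-scatter round.
num_comm = 2
chunk = reduced_flat.numel() // num_comm
all_comm_partitions = [[reduced_flat.narrow(0, c * chunk, chunk)] for c in range(num_comm)]

# Stitch the per-comm chunks back together into one flat gradient buffer.
flat_comm_grads = [torch.cat(rank_partitions) for rank_partitions in all_comm_partitions]
flat_all_grads = torch.cat(flat_comm_grads)

# Unflatten against the original parameter list and copy the gradients back.
for param, updated_grad in zip(params, _unflatten_dense_tensors(flat_all_grads, params)):
    param.grad.copy_(updated_grad)

# Every parameter now holds the reduced (all-ones) gradient.
assert all(torch.equal(p.grad, torch.ones_like(p)) for p in params)

In the commit itself, each entry of `all_comm_partitions` holds the partitions gathered for one communication step and the copy-back is restricted to `my_params`; the sketch only illustrates why concatenating the chunks in order lets `_unflatten_dense_tensors` line the flat buffer back up with the original parameter shapes.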