Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
6e65c2cc
Unverified
Commit
6e65c2cc
authored
Nov 24, 2020
by
Olatunji Ruwase
Committed by
GitHub
Nov 24, 2020
Browse files
Deprecate client ability to disable gradient reduction (#552)
Co-authored-by:
Jeff Rasley
<
jerasley@microsoft.com
>
parent
1ef5cd23
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
1 deletion
+12
-1
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+6
-1
deepspeed/runtime/zero/stage2.py
deepspeed/runtime/zero/stage2.py
+6
-0
No files found.
deepspeed/runtime/engine.py
View file @
6e65c2cc
...
...
@@ -878,6 +878,11 @@ class DeepSpeedEngine(Module):
allreduce_gradients: If this is False, then gradient averaging will be skipped. Default is True.
"""
if
not
allreduce_gradients
:
logger
.
warning
(
f
'Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed'
)
# scale loss w.r.t. gradient accumulation if needed
if
self
.
gradient_accumulation_steps
()
>
1
:
loss
=
self
.
_scale_loss
(
loss
.
float
())
...
...
@@ -931,7 +936,7 @@ class DeepSpeedEngine(Module):
self
.
timers
(
'backward_allreduce_microstep'
).
start
()
self
.
timers
(
'backward_allreduce'
).
start
()
if
allreduce_gradients
and
self
.
enable_backward_allreduce
:
if
self
.
enable_backward_allreduce
:
self
.
allreduce_gradients
()
if
self
.
wall_clock_breakdown
():
...
...
deepspeed/runtime/zero/stage2.py
View file @
6e65c2cc
...
...
@@ -955,6 +955,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
with
torch
.
cuda
.
stream
(
stream
):
for
_
,
param
,
param_id
in
self
.
params_in_ipg_bucket
:
assert
self
.
params_already_reduced
[
param_id
]
==
False
,
\
f
"The parameter
{
param_id
}
has already been reduced.
\
Gradient computed twice for this partition.
\
Multiple gradient reduction is currently not supported"
self
.
params_already_reduced
[
param_id
]
=
True
if
not
self
.
is_param_in_current_partition
[
param_id
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment