Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
a4763f55
Unverified
Commit
a4763f55
authored
Dec 12, 2020
by
carefree0910
Committed by
GitHub
Dec 11, 2020
Browse files
Supported customizing kwargs for lr_scheduler (#584)
Co-authored-by:
Jeff Rasley
<
jerasley@microsoft.com
>
parent
66268bd3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
6 deletions
+6
-6
deepspeed/runtime/engine.py
deepspeed/runtime/engine.py
+4
-4
deepspeed/runtime/pipe/engine.py
deepspeed/runtime/pipe/engine.py
+2
-2
No files found.
deepspeed/runtime/engine.py
View file @
a4763f55
...
@@ -979,7 +979,7 @@ class DeepSpeedEngine(Module):
...
@@ -979,7 +979,7 @@ class DeepSpeedEngine(Module):
torch
.
nn
.
utils
.
clip_grad_norm_
(
parameters
=
self
.
module
.
parameters
(),
torch
.
nn
.
utils
.
clip_grad_norm_
(
parameters
=
self
.
module
.
parameters
(),
max_norm
=
self
.
gradient_clipping
())
max_norm
=
self
.
gradient_clipping
())
def
_take_model_step
(
self
):
def
_take_model_step
(
self
,
lr_kwargs
):
if
self
.
gradient_clipping
()
>
0.0
:
if
self
.
gradient_clipping
()
>
0.0
:
if
not
self
.
fp16_enabled
()
and
not
self
.
amp_enabled
():
if
not
self
.
fp16_enabled
()
and
not
self
.
amp_enabled
():
self
.
clip_fp32_gradients
()
self
.
clip_fp32_gradients
()
...
@@ -1010,14 +1010,14 @@ class DeepSpeedEngine(Module):
...
@@ -1010,14 +1010,14 @@ class DeepSpeedEngine(Module):
self
.
skipped_steps
+=
1
self
.
skipped_steps
+=
1
else
:
else
:
if
self
.
lr_scheduler
is
not
None
:
if
self
.
lr_scheduler
is
not
None
:
self
.
lr_scheduler
.
step
()
self
.
lr_scheduler
.
step
(
**
(
lr_kwargs
or
{})
)
if
report_progress
and
(
self
.
global_steps
+
1
)
%
self
.
steps_per_print
()
==
0
:
if
report_progress
and
(
self
.
global_steps
+
1
)
%
self
.
steps_per_print
()
==
0
:
self
.
_report_progress
(
self
.
global_steps
+
1
)
self
.
_report_progress
(
self
.
global_steps
+
1
)
self
.
global_steps
+=
1
self
.
global_steps
+=
1
self
.
global_samples
+=
self
.
train_batch_size
()
self
.
global_samples
+=
self
.
train_batch_size
()
def
step
(
self
):
def
step
(
self
,
lr_kwargs
=
None
):
r
"""Execute the weight update step after forward and backward propagation
r
"""Execute the weight update step after forward and backward propagation
on effective_train_batch.
on effective_train_batch.
"""
"""
...
@@ -1034,7 +1034,7 @@ class DeepSpeedEngine(Module):
...
@@ -1034,7 +1034,7 @@ class DeepSpeedEngine(Module):
if
self
.
progressive_layer_drop
:
if
self
.
progressive_layer_drop
:
self
.
progressive_layer_drop
.
update_state
(
self
.
global_steps
)
self
.
progressive_layer_drop
.
update_state
(
self
.
global_steps
)
self
.
_take_model_step
()
self
.
_take_model_step
(
lr_kwargs
)
self
.
tput_timer
.
stop
(
report_progress
)
self
.
tput_timer
.
stop
(
report_progress
)
...
...
deepspeed/runtime/pipe/engine.py
View file @
a4763f55
...
@@ -940,14 +940,14 @@ class PipelineEngine(DeepSpeedEngine):
...
@@ -940,14 +940,14 @@ class PipelineEngine(DeepSpeedEngine):
if
self
.
wall_clock_breakdown
():
if
self
.
wall_clock_breakdown
():
self
.
timers
(
'pipe_recv_grad'
).
stop
()
self
.
timers
(
'pipe_recv_grad'
).
stop
()
def
_exec_optimizer_step
(
self
):
def
_exec_optimizer_step
(
self
,
lr_kwargs
=
None
):
if
self
.
wall_clock_breakdown
():
if
self
.
wall_clock_breakdown
():
self
.
timers
(
'step_microstep'
).
start
()
self
.
timers
(
'step_microstep'
).
start
()
self
.
timers
(
'step'
).
start
()
self
.
timers
(
'step'
).
start
()
self
.
mem_status
(
'BEFORE STEP'
,
reset_max
=
True
)
self
.
mem_status
(
'BEFORE STEP'
,
reset_max
=
True
)
self
.
_force_grad_boundary
=
True
self
.
_force_grad_boundary
=
True
self
.
_take_model_step
()
self
.
_take_model_step
(
lr_kwargs
)
self
.
_force_grad_boundary
=
False
self
.
_force_grad_boundary
=
False
self
.
mem_status
(
'AFTER STEP'
)
self
.
mem_status
(
'AFTER STEP'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment