Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
9f8e8f38
Unverified
Commit
9f8e8f38
authored
Dec 14, 2020
by
Stas Bekman
Committed by
GitHub
Dec 14, 2020
Browse files
implement missing get_last_lr (#595)
Co-authored-by:
Jeff Rasley
<
jerasley@microsoft.com
>
parent
c5a449f9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
0 deletions
+21
-0
deepspeed/runtime/lr_schedules.py
deepspeed/runtime/lr_schedules.py
+21
-0
No files found.
deepspeed/runtime/lr_schedules.py
View file @
9f8e8f38
...
@@ -381,6 +381,12 @@ class LRRangeTest(object):
...
@@ -381,6 +381,12 @@ class LRRangeTest(object):
lr_range_test_min_lr
*
lr_increase
for
lr_range_test_min_lr
in
self
.
min_lr
lr_range_test_min_lr
*
lr_increase
for
lr_range_test_min_lr
in
self
.
min_lr
]
]
def get_last_lr(self):
    """Return the learning rate(s) most recently applied by ``step``.

    ``_last_lr`` is only populated inside ``step``, so calling this
    before the first ``step`` raises ``AssertionError``.
    """
    last = getattr(self, '_last_lr', None)
    assert last is not None, "need to call step() first"
    return last
def
_update_optimizer
(
self
,
group_lrs
):
def
_update_optimizer
(
self
,
group_lrs
):
for
param_group
,
lr
in
zip
(
self
.
optimizer
.
param_groups
,
group_lrs
):
for
param_group
,
lr
in
zip
(
self
.
optimizer
.
param_groups
,
group_lrs
):
param_group
[
'lr'
]
=
lr
param_group
[
'lr'
]
=
lr
...
@@ -390,6 +396,7 @@ class LRRangeTest(object):
...
@@ -390,6 +396,7 @@ class LRRangeTest(object):
batch_iteration
=
self
.
last_batch_iteration
+
1
batch_iteration
=
self
.
last_batch_iteration
+
1
self
.
last_batch_iteration
=
batch_iteration
self
.
last_batch_iteration
=
batch_iteration
self
.
_update_optimizer
(
self
.
get_lr
())
self
.
_update_optimizer
(
self
.
get_lr
())
self
.
_last_lr
=
[
group
[
'lr'
]
for
group
in
self
.
optimizer
.
param_groups
]
def state_dict(self):
    """Serialize scheduler progress as a plain dict for checkpointing."""
    return dict(last_batch_iteration=self.last_batch_iteration)
...
@@ -628,12 +635,19 @@ class OneCycle(object):
...
@@ -628,12 +635,19 @@ class OneCycle(object):
return
self
.
_get_cycle_lr
()
return
self
.
_get_cycle_lr
()
return
self
.
_get_decay_lr
(
self
.
last_batch_iteration
-
self
.
total_size
)
return
self
.
_get_decay_lr
(
self
.
last_batch_iteration
-
self
.
total_size
)
def get_last_lr(self):
    """Return the learning rate(s) computed by the most recent ``step``.

    Raises ``AssertionError`` when ``step`` has never been called,
    because ``_last_lr`` does not exist until then.
    """
    assert getattr(self, '_last_lr', None) is not None, \
        "need to call step() first"
    return self._last_lr
def step(self, batch_iteration=None):
    """Advance the schedule one batch, or jump to ``batch_iteration``.

    Records the new position in ``last_batch_iteration``, writes the
    freshly computed learning rates into every optimizer param group,
    and snapshots them into ``_last_lr`` so ``get_last_lr`` can report
    what was actually applied.
    """
    if batch_iteration is None:
        batch_iteration = self.last_batch_iteration + 1
    self.last_batch_iteration = batch_iteration
    new_lrs = self.get_lr()
    for group, new_lr in zip(self.optimizer.param_groups, new_lrs):
        group['lr'] = new_lr
    self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
def state_dict(self):
    """Capture resumable scheduler state (the batch counter only)."""
    snapshot = {'last_batch_iteration': self.last_batch_iteration}
    return snapshot
...
@@ -690,12 +704,19 @@ class WarmupLR(object):
...
@@ -690,12 +704,19 @@ class WarmupLR(object):
self
.
delta_lrs
)
self
.
delta_lrs
)
]
]
def get_last_lr(self):
    """Return the last learning rate(s) written by ``step``.

    ``step`` must have run at least once so that ``_last_lr`` exists;
    otherwise an ``AssertionError`` is raised.
    """
    cached = getattr(self, '_last_lr', None)
    assert cached is not None, "need to call step() first"
    return cached
def step(self, last_batch_iteration=None):
    """Move the warmup schedule forward.

    Without an argument, advances one batch past the stored
    ``last_batch_iteration``; with one, jumps straight to it. Then
    applies the recomputed learning rates to each optimizer param
    group and caches them in ``_last_lr`` for ``get_last_lr``.
    """
    if last_batch_iteration is None:
        last_batch_iteration = self.last_batch_iteration + 1
    self.last_batch_iteration = last_batch_iteration
    for group, new_lr in zip(self.optimizer.param_groups, self.get_lr()):
        group['lr'] = new_lr
    self._last_lr = [g['lr'] for g in self.optimizer.param_groups]
def state_dict(self):
    """Return checkpointable state: just the batch-iteration counter."""
    return {'last_batch_iteration': self.last_batch_iteration}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment