OpenDAS / Megatron-LM · Commits

Commit 16193619, authored Nov 29, 2020 by mohammad
Parent: 65290033

    added refactored learning rate

Changes: 1 changed file, with 39 additions and 14 deletions.

megatron/learning_rates.py (+39, -14), view file @ 16193619
@@ -19,7 +19,6 @@ import math
 
 from megatron import print_rank_0
 
-
 class AnnealingLR(object):
     """Anneals the learning rate."""
 
@@ -31,44 +30,67 @@ class AnnealingLR(object):
         # Class values.
         self.optimizer = optimizer
-        self.start_lr = start_lr
+        self.start_lr = float(start_lr)
         self.min_lr = min_lr
+        assert self.min_lr >= 0.0
+        assert self.start_lr >= self.min_lr
         self.warmup_iter = warmup_iter
         self.num_iters = last_iter
         self.end_iter = total_iters
         assert self.end_iter > 0
+        assert self.warmup_iter < self.end_iter
         self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                 'use-checkpoint are set.'
+
         # Set the learning rate
         self.step(self.num_iters)
 
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
 
+
     def get_lr(self):
         """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
-        num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
-        # Warmup.
+        # Use linear warmup for the initial part.
         if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
-            return float(self.start_lr) * num_iters_ / self.warmup_iter
+            return self.start_lr * float(self.num_iters) / \
+                float(self.warmup_iter)
+
+        # If the learning rate is constant, just return the initial value.
+        if self.decay_style == 'constant':
+            return self.start_lr
+
+        # For any iterations larger than `self.end_iter`, use `self.min_lr`.
+        if self.num_iters > self.end_iter:
+            return self.min_lr
 
-        num_iters_ = num_iters_ - self.warmup_iter
+        # If we are done with the warmup period, use the decay style.
+        current_iter = self.num_iters - self.warmup_iter
+        decay_iters = self.end_iter - self.warmup_iter
+        decay_ratio = float(current_iter) / float(decay_iters)
+        assert decay_ratio >= 0.0
+        assert decay_ratio <= 1.0
+        delta_lr = self.start_lr - self.min_lr
+
         if self.decay_style == 'linear':
-            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
+            coeff = (1.0 - decay_ratio)
         elif self.decay_style == 'cosine':
-            lr = self.start_lr / 2.0 * (math.cos(
-                math.pi * num_iters_ / self.end_iter) + 1)
-        elif self.decay_style == 'exponential':
-            # exp(-0.693) = 1/2
-            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
+            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         else:
-            lr = self.start_lr
-        return max(lr, self.min_lr)
+            raise Exception('{} decay style is not supported.'.format(
+                self.decay_style))
+
+        return self.min_lr + coeff * delta_lr
 
     def step(self, step_num=None):
         """Set lr for all parameters groups."""
@@ -79,6 +101,7 @@ class AnnealingLR(object):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
 
+
     def state_dict(self):
         state_dict = {
             'start_lr': self.start_lr,
@@ -90,6 +113,7 @@ class AnnealingLR(object):
         }
         return state_dict
 
+
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
@@ -104,6 +128,7 @@ class AnnealingLR(object):
                 name))
         return sd_value
 
+
     def load_state_dict(self, sd):
 
         self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
...
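For completeness, a hypothetical usage sketch. The keyword names below are inferred from the attribute assignments in the constructor hunk (start_lr, warmup_iter, total_iters, decay_style, last_iter, min_lr, plus the two checkpoint flags); the authoritative signature is in megatron/learning_rates.py at this commit:

    import torch
    from megatron.learning_rates import AnnealingLR

    # Toy model and optimizer just to drive the scheduler; values are illustrative.
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0e-4)

    lr_scheduler = AnnealingLR(
        optimizer,
        start_lr=1.0e-4,
        warmup_iter=100,
        total_iters=1000,
        decay_style='cosine',
        last_iter=0,
        min_lr=1.0e-5,
        use_checkpoint_lr_scheduler=False,
        override_lr_scheduler=False)

    for _ in range(10):
        # ... forward/backward and optimizer.step() would go here ...
        lr_scheduler.step()  # advances num_iters and writes group['lr']

    # state_dict/load_state_dict round-trip; _check_and_set compares each
    # checkpointed hyperparameter against the constructor value unless
    # override_lr_scheduler is set.
    state = lr_scheduler.state_dict()
    lr_scheduler.load_state_dict(state)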