OpenDAS / Megatron-LM · Commits

Commit ff12df6b
Authored Nov 29, 2020 by mohammad
Parent 16193619

refactored learning rate scheduler so addition of variable batch size is easier
Showing 2 changed files with 60 additions and 42 deletions:

    megatron/learning_rates.py   +55  -37
    megatron/training.py          +5   -5
megatron/learning_rates.py

@@ -22,25 +22,25 @@ from megatron import print_rank_0
 class AnnealingLR(object):
     """Anneals the learning rate."""
 
-    def __init__(self, optimizer, start_lr,
-                 warmup_iter, total_iters,
-                 decay_style, last_iter, min_lr=0.0,
+    def __init__(self, optimizer, max_lr, min_lr,
+                 warmup_steps, decay_steps, decay_style,
+                 num_steps,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
 
         # Class values.
         self.optimizer = optimizer
-        self.start_lr = float(start_lr)
+        self.max_lr = float(max_lr)
         self.min_lr = min_lr
         assert self.min_lr >= 0.0
-        assert self.start_lr >= self.min_lr
-        self.warmup_iter = warmup_iter
-        self.num_iters = last_iter
-        self.end_iter = total_iters
-        assert self.end_iter > 0
-        assert self.warmup_iter < self.end_iter
+        assert self.max_lr >= self.min_lr
+        self.warmup_steps = warmup_steps
+        self.num_steps = num_steps
+        self.decay_steps = decay_steps
+        assert self.decay_steps > 0
+        assert self.warmup_steps < self.decay_steps
         self.decay_style = decay_style

@@ -51,7 +51,7 @@ class AnnealingLR(object):
             'use-checkpoint are set.'
 
         # Set the learning rate
-        self.step(self.num_iters)
+        self.step(step_num=self.num_steps)
 
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

@@ -61,25 +61,25 @@ class AnnealingLR(object):
         https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
         # Use linear warmup for the initial part.
-        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
-            return self.start_lr * float(self.num_iters) / \
-                float(self.warmup_iter)
+        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+            return self.max_lr * float(self.num_steps) / \
+                float(self.warmup_steps)
 
         # If the learning rate is constant, just return the initial value.
         if self.decay_style == 'constant':
-            return self.start_lr
+            return self.max_lr
 
-        # For any iterations larger than `self.end_iter`, use `self.min_lr`.
-        if self.num_iters > self.end_iter:
+        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.decay_steps:
             return self.min_lr
 
         # If we are done with the warmup period, use the decay style.
-        current_iter = self.num_iters - self.warmup_iter
-        decay_iters = self.end_iter - self.warmup_iter
-        decay_ratio = float(current_iter) / float(decay_iters)
+        num_steps_ = self.num_steps - self.warmup_steps
+        decay_steps_ = self.decay_steps - self.warmup_steps
+        decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
-        delta_lr = self.start_lr - self.min_lr
+        delta_lr = self.max_lr - self.min_lr
 
         if self.decay_style == 'linear':
             coeff = (1.0 - decay_ratio)

@@ -92,11 +92,11 @@ class AnnealingLR(object):
         return self.min_lr + coeff * delta_lr
 
-    def step(self, step_num=None):
+    def step(self, increment=1, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:
-            step_num = self.num_iters + 1
-        self.num_iters = step_num
+            step_num = self.num_steps + increment
+        self.num_steps = step_num
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

@@ -104,11 +104,11 @@ class AnnealingLR(object):
     def state_dict(self):
         state_dict = {
-            'start_lr': self.start_lr,
-            'warmup_iter': self.warmup_iter,
-            'num_iters': self.num_iters,
+            'max_lr': self.max_lr,
+            'warmup_steps': self.warmup_steps,
+            'num_steps': self.num_steps,
             'decay_style': self.decay_style,
-            'end_iter': self.end_iter,
+            'decay_steps': self.decay_steps,
             'min_lr': self.min_lr
         }
         return state_dict

@@ -131,18 +131,36 @@ class AnnealingLR(object):
     def load_state_dict(self, sd):
 
-        self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
-                                            'learning rate')
+        if 'start_lr' in sd:
+            max_lr_ = sd['start_lr']
+        else:
+            max_lr_ = sd['max_lr']
+        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
+                                          'learning rate')
         self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                           'minimum learning rate')
-        self.warmup_iter = self._check_and_set(self.warmup_iter, sd['warmup_iter'],
-                                               'warmup iterations')
-        self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
-                                            'total number of iterations')
+        if 'warmup_iter' in sd:
+            warmup_steps_ = sd['warmup_iter']
+        else:
+            warmup_steps_ = sd['warmup_steps']
+        self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_,
+                                                'warmup iterations')
+        if 'end_iter' in sd:
+            decay_steps_ = sd['end_iter']
+        else:
+            decay_steps_ = sd['decay_steps']
+        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
+                                               'total number of iterations')
         self.decay_style = self._check_and_set(self.decay_style,
                                                sd['decay_style'],
                                                'decay style')
 
-        self.num_iters = sd['num_iters']
-        self.step(self.num_iters)
+        if 'num_iters' in sd:
+            self.num_steps = sd['num_iters']
+        else:
+            self.num_steps = sd['num_steps']
+        self.step(step_num=self.num_steps)
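
The `increment` argument added to step() is the hook behind the commit message: the caller can advance the schedule by an arbitrary amount (for example, the number of samples consumed by a variable-size batch) rather than always advancing by one, and warmup_steps/decay_steps are then interpreted in the same unit. The following is an illustrative sketch only, not part of this commit; the training loop, batch-size ramp, and every numeric value are made up, and it assumes megatron.learning_rates imports cleanly outside the full training setup.

    # Illustrative sketch (not from this commit): driving the refactored
    # scheduler with a variable batch size by passing the samples consumed
    # per iteration to step().
    import torch
    from megatron.learning_rates import AnnealingLR  # assumed importable

    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    # Hypothetical values; warmup_steps/decay_steps are counted in whatever
    # unit the caller passes to step() -- here, samples.
    lr_scheduler = AnnealingLR(optimizer,
                               max_lr=1.0e-4,
                               min_lr=1.0e-5,
                               warmup_steps=1000,
                               decay_steps=100000,
                               decay_style='linear',
                               num_steps=0)

    # Hypothetical batch-size ramp: each iteration consumes a different
    # number of samples, and the schedule advances by exactly that amount.
    for batch_size in [64, 64, 128, 128, 256]:
        # ... forward / backward / optimizer.step() would happen here ...
        lr_scheduler.step(increment=batch_size)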
megatron/training.py

@@ -196,12 +196,12 @@ def get_learning_rate_scheduler(optimizer):
     warmup_iter = args.warmup * num_iters
     lr_scheduler = AnnealingLR(
         optimizer,
-        start_lr=args.lr,
-        warmup_iter=warmup_iter,
-        total_iters=num_iters,
-        decay_style=args.lr_decay_style,
-        last_iter=init_step,
+        max_lr=args.lr,
         min_lr=args.min_lr,
+        warmup_steps=warmup_iter,
+        decay_steps=num_iters,
+        decay_style=args.lr_decay_style,
+        num_steps=init_step,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
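
The `if 'start_lr' in sd` / `if 'warmup_iter' in sd` / `if 'end_iter' in sd` branches added to load_state_dict() keep checkpoints written before this rename loadable. A small illustrative continuation of the sketch above (not part of this commit; the state-dict values are made up and chosen to match that sketch):

    # Illustrative sketch, reusing lr_scheduler from the example above:
    # a state dict saved with the pre-rename keys still loads, because
    # load_state_dict() checks for 'start_lr', 'warmup_iter', 'end_iter'
    # and 'num_iters' before the new keys.
    old_style_sd = {
        'start_lr': 1.0e-4,     # loaded into self.max_lr
        'min_lr': 1.0e-5,
        'warmup_iter': 1000,    # loaded into self.warmup_steps
        'end_iter': 100000,     # loaded into self.decay_steps
        'decay_style': 'linear',
        'num_iters': 2500,      # loaded into self.num_steps
    }
    lr_scheduler.load_state_dict(old_style_sd)
    assert lr_scheduler.num_steps == 2500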