wuxk1 / Megatron-LM · Commit 04ecc834

Authored Jan 28, 2022 by Vijay Korthikanti

    changing class name AnnealingLR to OptimizerParamScheduler

Parent: 53931b8b

Showing 7 changed files with 60 additions and 57 deletions (+60 -57):
  megatron/arguments.py                  +2   -2
  megatron/checkpointing.py              +9   -6
  megatron/optimizer_param_scheduler.py  +9   -9
  megatron/training.py                   +23  -23
  megatron/utils.py                      +2   -2
  tasks/finetune_utils.py                +8   -8
  tasks/vision/finetune_utils.py         +7   -7
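At a glance, this is a mechanical rename: the import path, class name, CLI flags, and variable names change, while the call sites keep the same shape. A minimal before/after sketch (illustrative, not part of the diff):

    # Before this commit:
    #   from megatron.learning_rates import AnnealingLR
    #   lr_scheduler = AnnealingLR(optimizer, ...)
    # After this commit:
    from megatron.optimizer_param_scheduler import OptimizerParamScheduler
    # opt_param_scheduler = OptimizerParamScheduler(optimizer, ...)  # same arguments, new names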
megatron/arguments.py

@@ -577,13 +577,13 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-scheduler', action='store_true',
+    group.add_argument('--override-opt_param-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
                        'number of iterations, and decay style from input '
                        'arguments and ignore values from checkpoints. Note'
                        'that all the above values will be reset.')
-    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
+    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                        help='Use checkpoint to set the values of the scheduler '
                        '(learning rate, warmup iterations, minimum learning '
                        'rate, maximum number of iterations, and decay style '
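A side note on the new flag spelling: '--override-opt_param-scheduler' mixes dashes and an underscore, and argparse maps both to underscores when building the destination attribute, which is exactly the name megatron/training.py reads below (args.override_opt_param_scheduler). A minimal, self-contained sketch of that behavior (not code from the repository):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='learning rate')
    group.add_argument('--override-opt_param-scheduler', action='store_true')
    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true')

    args = parser.parse_args(['--use-checkpoint-opt_param-scheduler'])
    # argparse turns '-' into '_' when deriving attribute names:
    assert args.use_checkpoint_opt_param_scheduler is True
    assert args.override_opt_param_scheduler is False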
megatron/checkpointing.py

@@ -167,7 +167,7 @@ def get_rng_state():
     return rng_state_list


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         if not args.no_save_optim:
             if optimizer is not None:
                 state_dict['optimizer'] = optimizer.state_dict()
-            if lr_scheduler is not None:
-                state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()

         # RNG states.
         if not args.no_save_rng:

@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
     print_rank_0(" succesfully fixed query-key-values ordering for"
                  " checkpoint version {}".format(checkpoint_version))


-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
+def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of

@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            if opt_param_scheduler is not None:
+                if 'lr_scheduler' in state_dict: # backward compatbility
+                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                else:
+                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
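The load_checkpoint change above is the one behavioral addition in this commit: checkpoints written before the rename store the scheduler state under the 'lr_scheduler' key, newer ones under 'opt_param_scheduler', and the legacy key is checked first so old checkpoints stay loadable. A standalone sketch of the same pattern (hypothetical state_dict contents):

    # A pre-rename checkpoint (hypothetical contents):
    old_ckpt = {'iteration': 1000, 'lr_scheduler': {'num_steps': 1000}}
    # A post-rename checkpoint:
    new_ckpt = {'iteration': 1000, 'opt_param_scheduler': {'num_steps': 1000}}

    def scheduler_state(state_dict):
        # Prefer the legacy key for backward compatibility, as the diff does.
        if 'lr_scheduler' in state_dict:
            return state_dict['lr_scheduler']
        return state_dict['opt_param_scheduler']

    assert scheduler_state(old_ckpt) == scheduler_state(new_ckpt)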
megatron/learning_rates.py → megatron/optimizer_param_scheduler.py

@@ -19,14 +19,14 @@ import math
 from megatron import print_rank_0


-class AnnealingLR(object):
+class OptimizerParamScheduler(object):
     """Anneals the learning rate."""

     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
                  start_wd, end_wd, wd_incr_style,
-                 use_checkpoint_lr_scheduler=True,
-                 override_lr_scheduler=False):
+                 use_checkpoint_opt_param_scheduler=True,
+                 override_opt_param_scheduler=False):

         # Class values.
         self.optimizer = optimizer

@@ -51,10 +51,10 @@ class AnnealingLR(object):
         self.wd_incr_style = wd_incr_style

-        self.override_lr_scheduler = override_lr_scheduler
-        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
-        if self.override_lr_scheduler:
-            assert not self.use_checkpoint_lr_scheduler, 'both override and '\
+        self.override_opt_param_scheduler = override_opt_param_scheduler
+        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
+        if self.override_opt_param_scheduler:
+            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                 'use-checkpoint are set.'

         # Set the learning rate

@@ -147,11 +147,11 @@ class AnnealingLR(object):
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
-        if self.override_lr_scheduler:
+        if self.override_opt_param_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value

-        if not self.use_checkpoint_lr_scheduler:
+        if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
                 f'AnnealingLR: class input value {cls_value} and checkpoint'\
                 f'value {sd_value} for {name} do not match'
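For orientation, here is a hypothetical construction of the renamed class, using only the parameter names visible in this diff; every value below is made up, and 'optimizer' is assumed to be an already-built Megatron optimizer:

    from megatron.optimizer_param_scheduler import OptimizerParamScheduler

    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,                           # assumed: a Megatron optimizer instance
        max_lr=1.5e-4,
        min_lr=1.0e-5,
        warmup_steps=2000,
        decay_steps=300000,
        decay_style='cosine',                # assumed style name
        start_wd=0.01,                       # weight-decay schedule endpoints
        end_wd=0.01,
        wd_incr_style='constant',            # assumed style name
        use_checkpoint_opt_param_scheduler=True,
        override_opt_param_scheduler=False)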
megatron/training.py

@@ -43,7 +43,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
-from megatron.learning_rates import AnnealingLR
+from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model

@@ -118,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
                                                                model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '

@@ -149,7 +149,7 @@ def pretrain(train_valid_test_dataset_provider,
     iteration = 0
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
-                          model, optimizer, lr_scheduler,
+                          model, optimizer, opt_param_scheduler,
                           train_data_iterator, valid_data_iterator,
                           process_non_loss_data_func)
     print_datetime('after training is done')

@@ -162,7 +162,7 @@ def pretrain(train_valid_test_dataset_provider,
                                    False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

     if args.do_test:
         # Run on test data.

@@ -304,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     return model


-def get_learning_rate_scheduler(optimizer):
+def get_optimizer_param_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()

@@ -334,7 +334,7 @@ def get_learning_rate_scheduler(optimizer):
         raise Exception(
             'either train-iters or train-samples should be provided.')

-    lr_scheduler = AnnealingLR(
+    opt_param_scheduler = OptimizerParamScheduler(
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,

@@ -344,10 +344,10 @@ def get_learning_rate_scheduler(optimizer):
         start_wd=args.start_weight_decay,
         end_wd=args.end_weight_decay,
         wd_incr_style=args.weight_decay_incr_style,
-        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
-        override_lr_scheduler=args.override_lr_scheduler)
+        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
+        override_opt_param_scheduler=args.override_opt_param_scheduler)

-    return lr_scheduler
+    return opt_param_scheduler


 def setup_model_and_optimizer(model_provider_func,

@@ -365,7 +365,7 @@ def setup_model_and_optimizer(model_provider_func,
     optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
-    lr_scheduler = get_learning_rate_scheduler(optimizer)
+    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

     if args.load is not None:
         timers = get_timers()

@@ -373,7 +373,7 @@ def setup_model_and_optimizer(model_provider_func,
         # max time.
         torch.distributed.barrier()
         timers('load-checkpoint').start()
-        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])

@@ -392,11 +392,11 @@ def setup_model_and_optimizer(model_provider_func,
     if args.fp16:
         optimizer.reload_model_params()

-    return model, optimizer, lr_scheduler
+    return model, optimizer, opt_param_scheduler


 def train_step(forward_step_func, data_iterator,
-               model, optimizer, lr_scheduler):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()

@@ -472,7 +472,7 @@ def train_step(forward_step_func, data_iterator,
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size
-        lr_scheduler.step(increment=increment)
+        opt_param_scheduler.step(increment=increment)
         skipped_iter = 0
     else:
         skipped_iter = 1

@@ -662,19 +662,19 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag


-def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
     timers = get_timers()
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
     timers('save-checkpoint').start()
-    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
     torch.distributed.barrier()
     timers('save-checkpoint').stop()
     timers.log(['save-checkpoint'])


-def train(forward_step_func, model, optimizer, lr_scheduler,
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
           train_data_iterator, valid_data_iterator,
           process_non_loss_data_func):
     """Train the model function."""

@@ -704,7 +704,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       lr_scheduler)
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \

@@ -725,7 +725,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.adlr_autoresume and \
            (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
-                                              lr_scheduler)
+                                              opt_param_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \

@@ -742,14 +742,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             signal_handler = get_signal_handler()
             if any(signal_handler.signals_received()):
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
                 print_datetime('exiting program after receiving SIGTERM.')
                 sys.exit()

         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
             save_checkpoint_and_time(iteration, model, optimizer,
-                                     lr_scheduler)
+                                     opt_param_scheduler)
             saved_checkpoint = True

         # Exiting based on duration

@@ -763,7 +763,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             if done:
                 if not saved_checkpoint:
                     save_checkpoint_and_time(iteration, model, optimizer,
-                                             lr_scheduler)
+                                             opt_param_scheduler)
                 print_datetime('exiting program after {} minutes'.format(train_time))
                 sys.exit()

@@ -771,7 +771,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.exit_interval and iteration % args.exit_interval == 0:
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
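One detail worth noting in the train_step hunk: the scheduler is stepped by the number of samples consumed in the iteration, not by 1, so the warmup and decay horizons are expressed in the same sample-count units. A back-of-envelope sketch with hypothetical sizes:

    # Hypothetical batch geometry; the product is the global batch size,
    # i.e. the number of samples consumed per training iteration.
    num_microbatches = 8
    micro_batch_size = 4
    data_parallel_size = 16

    increment = num_microbatches * micro_batch_size * data_parallel_size
    assert increment == 512
    # opt_param_scheduler.step(increment=increment)   # as in the hunk above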
megatron/utils.py

@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):

 def check_adlr_autoresume_termination(iteration, model,
-                                      optimizer, lr_scheduler):
+                                      optimizer, opt_param_scheduler):
     """Check for autoresume signal and exit if it is received."""
     from megatron.checkpointing import save_checkpoint

@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
     torch.distributed.barrier()
     if autoresume.termination_requested():
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
         print_rank_0(">>> autoresume termination request found!")
         if torch.distributed.get_rank() == 0:
             autoresume.request_resume()
tasks/finetune_utils.py

@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
     return train_dataloader, valid_dataloader


-def _train(model, optimizer, lr_scheduler, forward_step,
+def _train(model, optimizer, opt_param_scheduler, forward_step,
            train_dataloader, valid_dataloader, end_of_epoch_callback):
     """Train the model."""
     args = get_args()

@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 start_iteration = 0

             # Train for one step.
-            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1

@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
                 check_adlr_autoresume_termination(iteration, model,
-                                                  optimizer, lr_scheduler)
+                                                  optimizer, opt_param_scheduler)

             # Checkpointing
             saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 saved_checkpoint = True

             # Evaluation

@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             # Exiting based on iterations
             if args.exit_interval and iteration % args.exit_interval == 0:
                 if not saved_checkpoint:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 torch.distributed.barrier()
                 print_rank_0('exiting program at iteration {}'.format(iteration))
                 sys.exit()

         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:

@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               model_type)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
+                                                                      model_type)
     timers('model and optimizer').stop()

     # If pretrained checkpoint is provided and we have not trained for

@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Finetune the model.
     if args.epochs > 0:
-        _train(model, optimizer, lr_scheduler, forward_step,
+        _train(model, optimizer, opt_param_scheduler, forward_step,
                train_dataloader, valid_dataloader, end_of_epoch_callback)
     # Or just evaluate.
     else:
tasks/vision/finetune_utils.py

@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,

@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
             iteration += 1

@@ -206,7 +206,7 @@ def _train(
                 iteration % args.adlr_autoresume_interval == 0
             ):
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )

             # Checkpointing

@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:

@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:

@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for

@@ -300,7 +300,7 @@ def finetune(
         _train(
             model,
             optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
             forward_step,
             train_dataloader,
             valid_dataloader,