OpenDAS / Megatron-LM / Commits / e724785f

Commit e724785f, authored Jan 28, 2022 by Jared Casper

Merge branch 'vision-merge' into 'main'

Second phase of vision code merge.

See merge request ADLR/megatron-lm!381

Parents: e156d2fe, 2b628f96
Showing 10 changed files with 536 additions and 143 deletions (+536, -143).
megatron/arguments.py                     +18   -2
megatron/checkpointing.py                 +9    -6
megatron/model/transformer.py             +69   -27
megatron/optimizer/__init__.py            +53   -20
megatron/optimizer_param_scheduler.py     +234  -0
megatron/schedules.py                     +53   -25
megatron/training.py                      +83   -46
megatron/utils.py                         +2    -2
tasks/finetune_utils.py                   +8    -8
tasks/vision/finetune_utils.py            +7    -7
megatron/arguments.py  (view file @ e724785f)

@@ -246,6 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
 
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
+    else:
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
+
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
     # Persistent fused layer norm.

@@ -395,6 +404,13 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
+    group.add_argument('--start-weight-decay', type=float,
+                       help='Initial weight decay coefficient for L2 regularization.')
+    group.add_argument('--end-weight-decay', type=float,
+                       help='End of run weight decay coefficient for L2 regularization.')
+    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
+                       choices=['constant', 'linear', 'cosine'],
+                       help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,

@@ -561,13 +577,13 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-scheduler', action='store_true',
+    group.add_argument('--override-opt_param-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
                        'number of iterations, and decay style from input '
                        'arguments and ignore values from checkpoints. Note'
                        'that all the above values will be reset.')
-    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
+    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                        help='Use checkpoint to set the values of the scheduler '
                        '(learning rate, warmup iterations, minimum learning '
                        'rate, maximum number of iterations, and decay style '
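For reference, a minimal standalone sketch (not part of the diff) of how the new weight-decay arguments interact, mirroring the parse_args() logic above outside of Megatron; the argument values passed to parse_args() here are purely illustrative:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weight-decay', type=float, default=0.01)
parser.add_argument('--start-weight-decay', type=float)
parser.add_argument('--end-weight-decay', type=float)
parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                    choices=['constant', 'linear', 'cosine'])
args = parser.parse_args(['--weight-decay-incr-style', 'linear',
                          '--start-weight-decay', '0.01',
                          '--end-weight-decay', '0.1'])

if args.weight_decay_incr_style == 'constant':
    # With the constant style both endpoints default to --weight-decay.
    assert args.start_weight_decay is None and args.end_weight_decay is None
    args.start_weight_decay = args.weight_decay
    args.end_weight_decay = args.weight_decay
else:
    # Linear/cosine styles require explicit endpoints.
    assert args.start_weight_decay is not None
    assert args.end_weight_decay is not None

print(args.start_weight_decay, args.end_weight_decay)  # prints: 0.01 0.1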
megatron/checkpointing.py  (view file @ e724785f)

@@ -167,7 +167,7 @@ def get_rng_state():
     return rng_state_list
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         if not args.no_save_optim:
             if optimizer is not None:
                 state_dict['optimizer'] = optimizer.state_dict()
-            if lr_scheduler is not None:
-                state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
 
         # RNG states.
         if not args.no_save_rng:

@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
         print_rank_0(" succesfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))
 
 
-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
+def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of

@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            if opt_param_scheduler is not None:
+                if 'lr_scheduler' in state_dict:  # backward compatbility
+                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                else:
+                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
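A standalone sketch of the backward-compatibility rule introduced above (the helper name is hypothetical, not part of the diff): checkpoints written before this change store the scheduler state under the 'lr_scheduler' key, newer ones under 'opt_param_scheduler', and both can feed the same scheduler object:

def load_scheduler_state(opt_param_scheduler, state_dict):
    # Checkpoint written before this commit: old key name.
    if 'lr_scheduler' in state_dict:
        opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
    # Checkpoint written after this commit: new key name.
    else:
        opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])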
megatron/model/transformer.py  (view file @ e724785f)

@@ -42,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """
 
+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_state):
+        if self.drop_prob == 0. or not self.training:
+            return hidden_state
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
+        random_tensor.floor_()  # binarize
+        output = hidden_state.div(keep_prob) * random_tensor
+        return output
+
+
 class ParallelMLP(MegatronModule):
     """MLP.

@@ -406,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule):
 
     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()
 
         super(ParallelTransformerLayer, self).__init__()

@@ -434,6 +458,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(

@@ -477,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states
 
-        # jit scripting for a nn.module (with dropout) is not
-        # trigerring the fusion kernel. For now, we use two
-        # different nn.functional routines to account for varying
-        # dropout semantics during training and inference phases.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
+        if self.drop_path is None:
+            # jit scripting for a nn.module (with dropout) is not
+            # trigerring the fusion kernel. For now, we use two
+            # different nn.functional routines to account for varying
+            # dropout semantics during training and inference phases.
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
             else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)
 
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

@@ -531,13 +562,19 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
 
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            output = bias_dropout_add_func(
-                mlp_output,
-                mlp_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path is None:
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                output = bias_dropout_add_func(
+                    mlp_output,
+                    mlp_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)
 
         return output

@@ -548,7 +585,8 @@ class ParallelTransformer(MegatronModule):
 
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()

@@ -557,6 +595,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate
 
         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method

@@ -567,6 +606,8 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
 
+        self.drop_path_rates = [rate.item() for rate in
+                                torch.linspace(0, self.drop_path_rate, args.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(

@@ -574,7 +615,8 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.drop_path_rates[layer_number - 1])
 
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
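A runnable standalone sketch of the DropPath (stochastic depth) module added above, written against plain torch.nn.Module instead of MegatronModule so it works outside Megatron. Each element of the leading dimension either keeps or drops the whole residual branch, and kept activations are rescaled so the expected value is unchanged:

import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Per-sample stochastic depth for the main path of a residual block."""

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per element of the leading dim, broadcast over the rest
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # becomes 0 or 1
        # divide by keep_prob so E[output] == input during training
        return hidden_state.div(keep_prob) * random_tensor

block = DropPath(drop_prob=0.2)
block.train()
x = torch.ones(4, 8, 16)
y = block(x)  # on average 20% of the 4 leading slices are zeroed, the rest scaled by 1/0.8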
megatron/optimizer/__init__.py  (view file @ e724785f)

@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
 
 
-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """creates param groups based on weight decay condition (regularized vs non regularized)
+       and learning rate scale condition (args.lr vs lr_mult * args.lr)
+       scale_lr_cond is used during finetuning where head of the network requires a scaled
+       version of the base learning rate.
     """
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None and p.requires_grad])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                # do not regularize biases nor Norm parameters
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups
 
 
-def get_megatron_optimizer(model):
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()
 
     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)
 
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
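A sketch of how the two new hooks could be used (the callables and the 'head.' prefix are hypothetical, not part of the diff): keep the default no-decay rule for biases and norms, and give a freshly initialized classification head a scaled learning rate during finetuning. The actual call is shown commented out because get_megatron_optimizer() assumes a fully initialized Megatron run:

def no_weight_decay_cond(name, param):
    # same default rule as get_param_groups: no decay for biases and 1-D params
    return name.endswith('.bias') or len(param.shape) == 1

def scale_lr_cond(name, param):
    # any parameter under a (hypothetical) 'head.' prefix gets lr_mult * args.lr
    return name.startswith('head.')

# optimizer = get_megatron_optimizer(model,
#                                    no_weight_decay_cond=no_weight_decay_cond,
#                                    scale_lr_cond=scale_lr_cond,
#                                    lr_mult=10.0)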
megatron/learning_rates.py → megatron/optimizer_param_scheduler.py  (view file @ e724785f)

@@ -13,19 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Learning rate decay functions."""
+"""Learning rate decay and weight decay incr functions."""
 
 import math
 
 from megatron import print_rank_0
 
 
-class AnnealingLR(object):
-    """Anneals the learning rate."""
+class OptimizerParamScheduler(object):
+    """Anneals learning rate and weight decay"""
 
     def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps, decay_style,
-                 use_checkpoint_lr_scheduler=True,
-                 override_lr_scheduler=False):
+                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
+                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
+                 use_checkpoint_opt_param_scheduler=True,
+                 override_opt_param_scheduler=False):
 
         # Class values.
         self.optimizer = optimizer

@@ -35,24 +36,55 @@ class AnnealingLR(object):
         assert self.min_lr >= 0.0
         assert self.max_lr >= self.min_lr
 
-        self.warmup_steps = warmup_steps
+        self.lr_warmup_steps = lr_warmup_steps
         self.num_steps = 0
-        self.decay_steps = decay_steps
-        assert self.decay_steps > 0
-        assert self.warmup_steps < self.decay_steps
-        self.decay_style = decay_style
-
-        self.override_lr_scheduler = override_lr_scheduler
-        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
-        if self.override_lr_scheduler:
-            assert not self.use_checkpoint_lr_scheduler, 'both override and '\
+        self.lr_decay_steps = lr_decay_steps
+        assert self.lr_decay_steps > 0
+        assert self.lr_warmup_steps < self.lr_decay_steps
+        self.lr_decay_style = lr_decay_style
+
+        self.start_wd = start_wd
+        self.end_wd = end_wd
+        assert self.start_wd >= 0.0
+        assert self.end_wd >= self.start_wd
+        self.wd_incr_steps = wd_incr_steps
+        self.wd_incr_style = wd_incr_style
+
+        self.override_opt_param_scheduler = override_opt_param_scheduler
+        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
+        if self.override_opt_param_scheduler:
+            assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
                 'use-checkpoint are set.'
 
         # Set the learning rate
         self.step(0)
-        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
+        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
 
+    def get_wd(self):
+        """ Weight decay incr functions"""
+        if self.num_steps > self.wd_incr_steps:
+            return self.end_wd
+
+        if self.wd_incr_style == 'constant':
+            assert self.start_wd == self.end_wd
+            return self.end_wd
+
+        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
+        assert incr_ratio >= 0.0
+        assert incr_ratio <= 1.0
+        delta_wd = self.end_wd - self.start_wd
+
+        if self.wd_incr_style == 'linear':
+            coeff = incr_ratio
+        elif self.wd_incr_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
+        else:
+            raise Exception('{} weight decay increment style is not '
+                            'supported.'.format(self.wd_incr_style))
+
+        return self.start_wd + coeff * delta_wd
 
     def get_lr(self):

@@ -60,33 +92,33 @@ class AnnealingLR(object):
         https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
         # Use linear warmup for the initial part.
-        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
             return self.max_lr * float(self.num_steps) / \
-                float(self.warmup_steps)
+                float(self.lr_warmup_steps)
 
         # If the learning rate is constant, just return the initial value.
-        if self.decay_style == 'constant':
+        if self.lr_decay_style == 'constant':
             return self.max_lr
 
-        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
-        if self.num_steps > self.decay_steps:
+        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.lr_decay_steps:
             return self.min_lr
 
         # If we are done with the warmup period, use the decay style.
-        num_steps_ = self.num_steps - self.warmup_steps
-        decay_steps_ = self.decay_steps - self.warmup_steps
+        num_steps_ = self.num_steps - self.lr_warmup_steps
+        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
         decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
         delta_lr = self.max_lr - self.min_lr
 
-        if self.decay_style == 'linear':
+        if self.lr_decay_style == 'linear':
             coeff = (1.0 - decay_ratio)
-        elif self.decay_style == 'cosine':
+        elif self.lr_decay_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         else:
             raise Exception('{} decay style is not supported.'.format(
-                self.decay_style))
+                self.lr_decay_style))
 
         return self.min_lr + coeff * delta_lr

@@ -95,18 +127,24 @@ class AnnealingLR(object):
         """Set lr for all parameters groups."""
         self.num_steps += increment
         new_lr = self.get_lr()
+        new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr
+            group['lr'] = new_lr * group.get('lr_mult', 1.0)
+            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
 
 
     def state_dict(self):
         state_dict = {
             'max_lr': self.max_lr,
-            'warmup_steps': self.warmup_steps,
+            'lr_warmup_steps': self.lr_warmup_steps,
             'num_steps': self.num_steps,
-            'decay_style': self.decay_style,
-            'decay_steps': self.decay_steps,
-            'min_lr': self.min_lr
+            'lr_decay_style': self.lr_decay_style,
+            'lr_decay_steps': self.lr_decay_steps,
+            'min_lr': self.min_lr,
+            'start_wd': self.start_wd,
+            'end_wd': self.end_wd,
+            'wd_incr_style': self.wd_incr_style,
+            'wd_incr_steps': self.wd_incr_steps
         }
         return state_dict

@@ -114,13 +152,13 @@ class AnnealingLR(object):
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
-        if self.override_lr_scheduler:
+        if self.override_opt_param_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value
 
-        if not self.use_checkpoint_lr_scheduler:
+        if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
-                f'AnnealingLR: class input value {cls_value} and checkpoint' \
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
                 f'value {sd_value} for {name} do not match'
         print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                   name))

@@ -140,25 +178,57 @@ class AnnealingLR(object):
                                           'minimum learning rate')
 
         if 'warmup_iter' in sd:
-            warmup_steps_ = sd['warmup_iter']
+            lr_warmup_steps_ = sd['warmup_iter']
+        elif 'warmup_steps' in sd:
+            lr_warmup_steps_ = sd['warmup_steps']
         else:
-            warmup_steps_ = sd['warmup_steps']
-        self.warmup_steps = self._check_and_set(self.warmup_steps,
-                                                warmup_steps_,
-                                                'warmup iterations')
+            lr_warmup_steps_ = sd['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
+                                                   lr_warmup_steps_,
+                                                   'warmup iterations')
 
         if 'end_iter' in sd:
-            decay_steps_ = sd['end_iter']
+            lr_decay_steps_ = sd['end_iter']
+        elif 'decay_steps' in sd:
+            lr_decay_steps_ = sd['decay_steps']
         else:
-            decay_steps_ = sd['decay_steps']
-        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
-                                               'total number of iterations')
-        self.decay_style = self._check_and_set(self.decay_style,
-                                               sd['decay_style'],
-                                               'decay style')
+            lr_decay_steps_ = sd['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps,
+                                                  lr_decay_steps_,
+                                                  'total number of iterations')
+
+        if 'decay_style' in sd:
+            lr_decay_style_ = sd['decay_style']
+        else:
+            lr_decay_style_ = sd['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
+                                                  lr_decay_style_,
+                                                  'learning rate decay style')
 
         if 'num_iters' in sd:
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
         self.step(increment=num_steps)
+
+        if 'start_wd' in sd:
+            self.start_wd = self._check_and_set(self.start_wd,
                                                sd['start_wd'],
                                                "start weight decay")
+            self.end_wd = self._check_and_set(self.end_wd,
                                              sd['end_wd'],
                                              "end weight decay")
+            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
                                                     sd['wd_incr_steps'],
                                                     "total number of weight decay iterations")
+            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
                                                     sd['wd_incr_style'],
                                                     "weight decay incr style")
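A toy usage sketch of the new scheduler, assuming the module can be imported and used outside a full Megatron launch (in practice it is constructed by get_optimizer_param_scheduler() in megatron/training.py); the model, step counts, and hyperparameters below are illustrative only:

import torch
from megatron.optimizer_param_scheduler import OptimizerParamScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

scheduler = OptimizerParamScheduler(
    optimizer,
    max_lr=1e-3, min_lr=1e-5,
    lr_warmup_steps=100, lr_decay_steps=1000, lr_decay_style='cosine',
    start_wd=0.01, end_wd=0.1, wd_incr_steps=1000, wd_incr_style='linear',
    use_checkpoint_opt_param_scheduler=False,
    override_opt_param_scheduler=True)

for _ in range(10):
    # ... forward / backward / optimizer.step() would go here ...
    # each call updates both 'lr' and 'weight_decay' on every param group,
    # scaled by that group's 'lr_mult' / 'wd_mult' (defaulting to 1.0)
    scheduler.step(increment=1)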
megatron/schedules.py  (view file @ e724785f)

@@ -98,7 +98,12 @@ def custom_backward(output, grad_output):
     )
 
 
-def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
+def forward_step(forward_step_func,
+                 data_iterator,
+                 model,
+                 input_tensor,
+                 forward_data_store,
+                 collect_non_loss_data=False):
     """Forward step for passed-in model.
 
     If first stage, input tensor is obtained from data_iterator, otherwise

@@ -120,10 +125,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+        if not collect_non_loss_data:
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
+
     timers('forward-compute').stop()
 
     # If T5 model (or other model with encoder and decoder)

@@ -206,8 +216,12 @@ def dummy_handler():
         pass
 
 
-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers, forward_only):
+def forward_backward_no_pipelining(forward_step_func,
+                                   data_iterator, model,
+                                   optimizer,
+                                   timers,
+                                   forward_only,
+                                   collect_non_loss_data=False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).

@@ -219,35 +233,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync
 
-    losses_reduced = []
+    forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                         input_tensor, losses_reduced)
+            output_tensor = forward_step(forward_step_func, data_iterator,
+                                         model, input_tensor, forward_data_store,
+                                         collect_non_loss_data)
            if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                 input_tensor, losses_reduced)
+    output_tensor = forward_step(forward_step_func, data_iterator,
+                                 model, input_tensor, forward_data_store,
+                                 collect_non_loss_data)
     if not forward_only:
         backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
 
-    return losses_reduced
+    return forward_data_store
 
 
-def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
-                                                  optimizer, timers, forward_only):
+def forward_backward_pipelining_with_interleaving(forward_step_func,
+                                                  data_iterator, model,
+                                                  optimizer,
+                                                  timers,
+                                                  forward_only,
+                                                  collect_non_loss_data=False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
-    losses_reduced = []
+    forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]

@@ -307,7 +327,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         output_tensor = forward_step(forward_step_func,
                                      data_iterator[model_chunk_id],
                                      model[model_chunk_id],
-                                     input_tensor, losses_reduced)
+                                     input_tensor,
+                                     forward_data_store,
+                                     collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)
 
         # if forward-only, no need to save tensors for a backward pass

@@ -474,7 +496,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                     tensor_shape=tensor_shape,
                     timers=timers))
 
-    return losses_reduced
+    return forward_data_store
 
 
 def get_tensor_shapes(rank, model_type):

@@ -571,9 +593,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
     return input_tensors
 
 
-def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
-                                                      model, optimizer, timers,
-                                                      forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func,
+                                                     data_iterator,
+                                                     model,
+                                                     optimizer,
+                                                     timers,
+                                                     forward_only,
+                                                     collect_non_loss_data=False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.

@@ -608,13 +634,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     if not forward_only:
         input_tensors = []
         output_tensors = []
-    losses_reduced = []
+    forward_data_store = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:

@@ -633,7 +660,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         last_iteration = (i == (num_microbatches_remaining - 1))
 
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)

@@ -682,4 +710,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
 
-    return losses_reduced
+    return forward_data_store
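A sketch of the contract this introduces on the task side (the model call, batch layout, and loss name are hypothetical, not part of the diff): the loss_func returned by a forward_step_func should accept the new non_loss_data flag and either reduce the loss as before or hand raw outputs back for later visualization:

import functools
import torch

def loss_func(labels, output_tensor, non_loss_data=False):
    if non_loss_data:
        # returned tensors are appended to forward_data_store and later handed
        # to process_non_loss_data_func (e.g. images for tensorboard)
        return output_tensor
    loss = torch.nn.functional.cross_entropy(output_tensor, labels)
    averaged_loss = loss.detach()  # in Megatron this would be a data-parallel all-reduce
    return loss, {'lm loss': averaged_loss}

def forward_step(data_iterator, model):
    tokens, labels = next(data_iterator)       # hypothetical batch layout
    output_tensor = model(tokens)
    return output_tensor, functools.partial(loss_func, labels)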
megatron/training.py  (view file @ e724785f)

@@ -43,7 +43,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
-from megatron.learning_rates import AnnealingLR
+from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model

@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
              model_provider,
              model_type,
              forward_step_func,
+             process_non_loss_data_func=None,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.

@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
             the info we would like to monitor during training, for example
             `lm-loss: value`. We also require that this function add
             `batch generator` to the timers class.
+        process_non_loss_data_func: a function to post process outputs of the
+            network. It can be used for dumping output tensors (e.g images) to
+            tensorboard. It takes `collected data`(list of tensors),
+            `current iteration index` and `tensorboard writer` as arguments.
         extra_args_provider: a function that takes a parser and adds arguments
             to it. It is used for programs to add their own arguments.
         args_defaults: a dictionary from argument-name to argument-value. It

@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               model_type)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
+        model_provider, model_type)
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '

@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider,
     iteration = 0
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
-                          model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
+                          model, optimizer, opt_param_scheduler,
+                          train_data_iterator, valid_data_iterator,
+                          process_non_loss_data_func)
     print_datetime('after training is done')
 
     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    valid_data_iterator, model,
-                                   iteration, False)
+                                   iteration, process_non_loss_data_func,
+                                   False)
 
     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler)
+        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
 
     if args.do_test:
         # Run on test data.
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
-                                   0, True)
+                                   0, process_non_loss_data_func,
+                                   True)
 
 
 def update_train_iters(args):

@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     return model
 
 
-def get_learning_rate_scheduler(optimizer):
+def get_optimizer_param_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()

@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer):
     if args.train_iters:
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
-        decay_steps = args.lr_decay_iters * args.global_batch_size
+        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
+        wd_incr_steps = args.train_iters * args.global_batch_size
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_iters * args.global_batch_size
+            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
     elif args.train_samples:
         # We need to set training iters for later use. Technically

@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer):
         update_train_iters(args)
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
-        decay_steps = args.lr_decay_samples
+        lr_decay_steps = args.lr_decay_samples
+        wd_incr_steps = args.train_samples
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_samples
+            lr_warmup_steps = args.lr_warmup_samples
     else:
         raise Exception(
             'either train-iters or train-samples should be provided.')
 
-    lr_scheduler = AnnealingLR(
+    opt_param_scheduler = OptimizerParamScheduler(
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
-        warmup_steps=warmup_steps,
-        decay_steps=decay_steps,
-        decay_style=args.lr_decay_style,
-        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
-        override_lr_scheduler=args.override_lr_scheduler)
+        lr_warmup_steps=lr_warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        lr_decay_style=args.lr_decay_style,
+        start_wd=args.start_weight_decay,
+        end_wd=args.end_weight_decay,
+        wd_incr_steps=wd_incr_steps,
+        wd_incr_style=args.weight_decay_incr_style,
+        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
+        override_opt_param_scheduler=args.override_opt_param_scheduler)
 
-    return lr_scheduler
+    return opt_param_scheduler
 
 
-def setup_model_and_optimizer(model_provider_func, model_type):
+def setup_model_and_optimizer(model_provider_func,
+                              model_type,
+                              no_wd_decay_cond=None,
+                              scale_lr_cond=None,
+                              lr_mult=1.0):
     """Setup model and optimizer."""
     args = get_args()

@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model)
-    lr_scheduler = get_learning_rate_scheduler(optimizer)
+    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
+    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
 
     if args.load is not None:
         timers = get_timers()

@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
         # max time.
         torch.distributed.barrier()
         timers('load-checkpoint').start()
-        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])

@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     if args.fp16:
         optimizer.reload_model_params()
 
-    return model, optimizer, lr_scheduler
+    return model, optimizer, opt_param_scheduler
 
 
 def train_step(forward_step_func, data_iterator,
-               model, optimizer, lr_scheduler):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()

@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator,
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size
-        lr_scheduler.step(increment=increment)
+        opt_param_scheduler.step(increment=increment)
         skipped_iter = 0
     else:
         skipped_iter = 1

@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag
 
 
-def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
     timers = get_timers()
     # Extra barrier is added to make sure
     # all ranks report the max time.
     torch.distributed.barrier()
     timers('save-checkpoint').start()
-    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
     torch.distributed.barrier()
     timers('save-checkpoint').stop()
     timers.log(['save-checkpoint'])
 
 
-def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, valid_data_iterator):
+def train(forward_step_func, model, optimizer, opt_param_scheduler,
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func):
     """Train the model function."""
     args = get_args()
     timers = get_timers()

@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       lr_scheduler)
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \

@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.adlr_autoresume and \
            (iteration % args.adlr_autoresume_interval == 0):
             check_adlr_autoresume_termination(iteration, model, optimizer,
-                                              lr_scheduler)
+                                              opt_param_scheduler)
 
         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \

@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
-                                       iteration, False)
+                                       iteration, process_non_loss_data_func,
+                                       False)
 
         # Checkpointing
         saved_checkpoint = False

@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             signal_handler = get_signal_handler()
             if any(signal_handler.signals_received()):
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
                 print_datetime('exiting program after receiving SIGTERM.')
                 sys.exit()
 
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
             save_checkpoint_and_time(iteration, model, optimizer,
-                                     lr_scheduler)
+                                     opt_param_scheduler)
             saved_checkpoint = True
 
         # Exiting based on duration

@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             if done:
                 if not saved_checkpoint:
                     save_checkpoint_and_time(iteration, model, optimizer,
-                                             lr_scheduler)
+                                             opt_param_scheduler)
                 print_datetime('exiting program after {} minutes'.format(train_time))
                 sys.exit()

@@ -753,7 +774,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.exit_interval and iteration % args.exit_interval == 0:
             if not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+                                         opt_param_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()

@@ -762,7 +783,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     return iteration
 
 
-def evaluate(forward_step_func, data_iterator, model, verbose=False):
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             verbose=False):
     """Evaluation."""
     args = get_args()

@@ -799,6 +824,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()
+
+    collected_non_loss_data = None
+    if process_non_loss_data_func is not None and is_last_rank():
+        collected_non_loss_data = forward_backward_func(
+            forward_step_func, data_iterator, model, optimizer=None,
+            timers=None, forward_only=True, collect_non_loss_data=True)
+
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()

@@ -806,16 +837,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
 
-    return total_loss_dict
+    return total_loss_dict, collected_non_loss_data
+
 
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
-                               iteration, verbose=False):
+                               iteration, process_non_loss_data_func,
+                               verbose=False):
     """Helper function to evaluate and dump results on screen."""
     args = get_args()
     writer = get_tensorboard_writer()
 
-    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+    total_loss_dict, collected_non_loss_data = evaluate(
+        forward_step_func, data_iterator, model,
+        process_non_loss_data_func, verbose)
     string = ' validation loss at {} | '.format(prefix)
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())

@@ -834,6 +868,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar('{} validation ppl vs samples'.format(key),
                               ppl, args.consumed_train_samples)
 
+    if process_non_loss_data_func is not None and writer and is_last_rank():
+        process_non_loss_data_func(collected_non_loss_data, iteration, writer)
+
     length = len(string) + 1
     print_rank_last('-' * length)
     print_rank_last(string)
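A sketch of a process_non_loss_data_func matching the new pretrain() docstring (the tensor layout and tag are hypothetical): dump the first collected batch of images to tensorboard at each evaluation. The pretrain() call is shown commented out and is illustrative only:

def process_non_loss_data_func(collected_data, iteration, writer):
    # collected_data is the list returned by the forward-backward function when
    # collect_non_loss_data=True; here we assume each element is an NCHW image tensor.
    images = collected_data[0].clamp(0, 1)
    writer.add_images('eval images', images, global_step=iteration)

# pretrain(train_valid_test_dataset_provider, model_provider,
#          ModelType.encoder_or_decoder, forward_step_func,
#          process_non_loss_data_func=process_non_loss_data_func)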
megatron/utils.py  (view file @ e724785f)

@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):
 
 
 def check_adlr_autoresume_termination(iteration, model,
-                                      optimizer, lr_scheduler):
+                                      optimizer, opt_param_scheduler):
     """Check for autoresume signal and exit if it is received."""
     from megatron.checkpointing import save_checkpoint

@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
     torch.distributed.barrier()
     if autoresume.termination_requested():
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
         print_rank_0(">>> autoresume termination request found!")
         if torch.distributed.get_rank() == 0:
             autoresume.request_resume()
tasks/finetune_utils.py  (view file @ e724785f)

@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
     return train_dataloader, valid_dataloader
 
 
-def _train(model, optimizer, lr_scheduler, forward_step,
+def _train(model, optimizer, opt_param_scheduler, forward_step,
            train_dataloader, valid_dataloader, end_of_epoch_callback):
     """Train the model."""
     args = get_args()

@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0
 
             # Train for one step.
-            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1

@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             if args.adlr_autoresume and \
                (iteration % args.adlr_autoresume_interval == 0):
                 check_adlr_autoresume_termination(iteration, model,
-                                                  optimizer, lr_scheduler)
+                                                  optimizer, opt_param_scheduler)
 
             # Checkpointing
             saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 saved_checkpoint = True
 
             # Evaluation

@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             # Exiting based on iterations
             if args.exit_interval and iteration % args.exit_interval == 0:
                 if not saved_checkpoint:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                 torch.distributed.barrier()
                 print_rank_0('exiting program at iteration {}'.format(iteration))
                 sys.exit()
 
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
 
         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:

@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Build model, optimizer and learning rate scheduler.
     timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
     timers('model and optimizer').stop()
 
     # If pretrained checkpoint is provided and we have not trained for

@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     # Finetune the model.
     if args.epochs > 0:
-        _train(model, optimizer, lr_scheduler, forward_step,
+        _train(model, optimizer, opt_param_scheduler, forward_step,
                train_dataloader, valid_dataloader, end_of_epoch_callback)
     # Or just evaluate.
     else:
tasks/vision/finetune_utils.py  (view file @ e724785f)

@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,

@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
 
             iteration += 1

@@ -206,7 +206,7 @@ def _train(
                iteration % args.adlr_autoresume_interval == 0:
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )
 
             # Checkpointing

@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
 
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:

@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
 
         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:

@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()
 
     # If pretrained checkpoint is provided and we have not trained for

@@ -300,7 +300,7 @@ def finetune(
         _train(
             model,
             optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
             forward_step,
             train_dataloader,
             valid_dataloader,