OpenDAS / Megatron-LM — Commit e724785f
Authored Jan 28, 2022 by Jared Casper

Merge branch 'vision-merge' into 'main'

second phase of vision code merge

See merge request ADLR/megatron-lm!381

Parents: e156d2fe, 2b628f96

Showing 10 changed files with 536 additions and 143 deletions (+536 −143)
megatron/arguments.py                    +18    −2
megatron/checkpointing.py                 +9    −6
megatron/model/transformer.py            +69   −27
megatron/optimizer/__init__.py           +53   −20
megatron/optimizer_param_scheduler.py   +234    −0
megatron/schedules.py                    +53   −25
megatron/training.py                     +83   −46
megatron/utils.py                         +2    −2
tasks/finetune_utils.py                   +8    −8
tasks/vision/finetune_utils.py            +7    −7
megatron/arguments.py  (+18, −2)

@@ -246,6 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'

+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
+    else:
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
+
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
     # Persistent fused layer norm.

@@ -395,6 +404,13 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
+    group.add_argument('--start-weight-decay', type=float,
+                       help='Initial weight decay coefficient for L2 regularization.')
+    group.add_argument('--end-weight-decay', type=float,
+                       help='End of run weight decay coefficient for L2 regularization.')
+    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
+                       choices=['constant', 'linear', 'cosine'],
+                       help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,

@@ -561,13 +577,13 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-scheduler', action='store_true',
+    group.add_argument('--override-opt_param-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
                        'number of iterations, and decay style from input '
                        'arguments and ignore values from checkpoints. Note'
                        'that all the above values will be reset.')
-    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
+    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                        help='Use checkpoint to set the values of the scheduler '
                        '(learning rate, warmup iterations, minimum learning '
                        'rate, maximum number of iterations, and decay style '
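A minimal standalone sketch (not the Megatron CLI itself) of how the new weight-decay arguments interact, mirroring the parse_args() validation above: with the default 'constant' style the single --weight-decay value is reused as both start and end, while 'linear'/'cosine' require explicit --start-weight-decay and --end-weight-decay.

    # illustrative argparse sketch; the flag names match the diff above,
    # the parser object itself is hypothetical
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--start-weight-decay', type=float)
    parser.add_argument('--end-weight-decay', type=float)
    parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'])

    args = parser.parse_args(['--weight-decay-incr-style', 'cosine',
                              '--start-weight-decay', '0.01',
                              '--end-weight-decay', '0.1'])

    if args.weight_decay_incr_style == 'constant':
        # constant style falls back to the single --weight-decay value
        assert args.start_weight_decay is None and args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        # linear/cosine styles require explicit start and end values
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    print(args.start_weight_decay, args.end_weight_decay)  # 0.01 0.1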
megatron/checkpointing.py  (+9, −6)

@@ -167,7 +167,7 @@ def get_rng_state():
     return rng_state_list


-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         if not args.no_save_optim:
             if optimizer is not None:
                 state_dict['optimizer'] = optimizer.state_dict()
-            if lr_scheduler is not None:
-                state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+            if opt_param_scheduler is not None:
+                state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()

         # RNG states.
         if not args.no_save_rng:

@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
         print_rank_0(" succesfully fixed query-key-values ordering for"
                      " checkpoint version {}".format(checkpoint_version))


-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
+def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in
         :attr:`state_dict` of the checkpoint match the names of

@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if optimizer is not None:
                 optimizer.load_state_dict(state_dict['optimizer'])
-            if lr_scheduler is not None:
-                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+            if opt_param_scheduler is not None:
+                if 'lr_scheduler' in state_dict: # backward compatbility
+                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                else:
+                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
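A standalone sketch of the backward-compatible scheduler load shown above: old checkpoints stored the scheduler state under 'lr_scheduler', new ones under 'opt_param_scheduler'. The helper name below is illustrative, not Megatron API.

    def load_scheduler_state(opt_param_scheduler, state_dict):
        # checkpoint written before this commit
        if 'lr_scheduler' in state_dict:
            opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
        # checkpoint written after the rename
        else:
            opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])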
megatron/model/transformer.py  (+69, −27)

@@ -42,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
         hyperparameters: transformer hyperparameters
 """

+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_state):
+        if self.drop_prob == 0. or not self.training:
+            return hidden_state
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
+        random_tensor.floor_()  # binarize
+        output = hidden_state.div(keep_prob) * random_tensor
+        return output
+
 class ParallelMLP(MegatronModule):
     """MLP.

@@ -406,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()

@@ -434,6 +458,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(

@@ -477,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states

-        # jit scripting for a nn.module (with dropout) is not
-        # trigerring the fusion kernel. For now, we use two
-        # different nn.functional routines to account for varying
-        # dropout semantics during training and inference phases.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
-            else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path is None:
+            # jit scripting for a nn.module (with dropout) is not
+            # trigerring the fusion kernel. For now, we use two
+            # different nn.functional routines to account for varying
+            # dropout semantics during training and inference phases.
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
+            else:
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)

         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

@@ -531,13 +562,19 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            output = bias_dropout_add_func(
-                mlp_output,
-                mlp_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path is None:
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                output = bias_dropout_add_func(
+                    mlp_output,
+                    mlp_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)

         return output

@@ -548,7 +585,8 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()

@@ -557,6 +595,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method

@@ -567,6 +606,8 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

+        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(

@@ -574,7 +615,8 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.drop_path_rates[layer_number - 1])
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
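A self-contained demo of the DropPath (stochastic depth) mechanism added above, using plain torch.nn.Module in place of MegatronModule so it runs outside Megatron. The forward pass draws one random keep/drop decision per index of the first dimension and rescales the kept entries by 1/keep_prob so the expected value of the residual branch is unchanged.

    import torch

    class DropPath(torch.nn.Module):
        def __init__(self, drop_prob=0.):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, hidden_state):
            if self.drop_prob == 0. or not self.training:
                return hidden_state
            keep_prob = 1 - self.drop_prob
            # one random draw per index along the first dimension,
            # broadcast over the remaining dims
            shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
            random_tensor = keep_prob + torch.rand(
                shape, dtype=hidden_state.dtype, device=hidden_state.device)
            random_tensor.floor_()  # binarize: 1 keeps the branch, 0 drops it
            # rescale so the expectation matches the no-drop case
            return hidden_state.div(keep_prob) * random_tensor

    torch.manual_seed(0)
    drop = DropPath(drop_prob=0.2).train()
    x = torch.ones(4, 3)            # toy "branch output" with 4 rows
    residual = torch.zeros_like(x)
    out = residual + drop(x)        # ~20% of rows zeroed, the rest scaled by 1/0.8
    print(out)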
megatron/optimizer/__init__.py  (+53, −20)

@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
-    """
-
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None and p.requires_grad])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
-
-
-def get_megatron_optimizer(model):
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """creates param groups based on weight decay condition (regularized vs non regularized)
+       and learning rate scale condition (args.lr vs lr_mult * args.lr)
+       scale_lr_cond is used during finetuning where head of the network requires a scaled
+       version of the base learning rate.
+    """
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                # do not regularize biases nor Norm parameters
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups
+
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()

     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)
+
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
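A standalone illustration of the grouping rule used by get_param_groups above: by default, biases and 1-D parameters (e.g. LayerNorm weights) get no weight decay, and an optional scale_lr_cond routes selected parameters (such as a finetuning head) to a group whose learning rate is multiplied by lr_mult. The head_lr_scale selector below is hypothetical.

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),     # parameters '0.weight', '0.bias'
        torch.nn.LayerNorm(8),     # parameters '1.weight', '1.bias'
    )

    def default_no_wd(name, param):
        # same default condition as get_param_groups
        return name.endswith('.bias') or len(param.shape) == 1

    def head_lr_scale(name, param):
        # hypothetical "head" selector: scale the LayerNorm's learning rate
        return name.startswith('1.')

    for name, param in model.named_parameters():
        print(name,
              'no_wd' if default_no_wd(name, param) else 'wd',
              'scaled_lr' if head_lr_scale(name, param) else 'base_lr')
    # 0.weight -> wd,    base_lr
    # 0.bias   -> no_wd, base_lr
    # 1.weight -> no_wd, scaled_lr
    # 1.bias   -> no_wd, scaled_lr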
megatron/learning_rates.py → megatron/optimizer_param_scheduler.py  (+234, −0)

@@ -13,19 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""Learning rate decay functions."""
+"""Learning rate decay and weight decay incr functions."""

 import math

 from megatron import print_rank_0


-class AnnealingLR(object):
-    """Anneals the learning rate."""
+class OptimizerParamScheduler(object):
+    """Anneals learning rate and weight decay"""

-    def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps, decay_style,
-                 use_checkpoint_lr_scheduler=True,
-                 override_lr_scheduler=False):
+    def __init__(self, optimizer, max_lr, min_lr,
+                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
+                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
+                 use_checkpoint_opt_param_scheduler=True,
+                 override_opt_param_scheduler=False):

         # Class values.
         self.optimizer = optimizer

@@ -35,24 +36,55 @@ class AnnealingLR(object):
         assert self.min_lr >= 0.0
         assert self.max_lr >= self.min_lr

-        self.warmup_steps = warmup_steps
+        self.lr_warmup_steps = lr_warmup_steps
         self.num_steps = 0
-        self.decay_steps = decay_steps
-        assert self.decay_steps > 0
-        assert self.warmup_steps < self.decay_steps
-        self.decay_style = decay_style
-
-        self.override_lr_scheduler = override_lr_scheduler
-        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
-        if self.override_lr_scheduler:
-            assert not self.use_checkpoint_lr_scheduler, 'both override and ' \
+        self.lr_decay_steps = lr_decay_steps
+        assert self.lr_decay_steps > 0
+        assert self.lr_warmup_steps < self.lr_decay_steps
+        self.lr_decay_style = lr_decay_style
+
+        self.start_wd = start_wd
+        self.end_wd = end_wd
+        assert self.start_wd >= 0.0
+        assert self.end_wd >= self.start_wd
+        self.wd_incr_steps = wd_incr_steps
+        self.wd_incr_style = wd_incr_style
+
+        self.override_opt_param_scheduler = override_opt_param_scheduler
+        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
+        if self.override_opt_param_scheduler:
+            assert not self.use_checkpoint_opt_param_scheduler, 'both override and ' \
                 'use-checkpoint are set.'

         # Set the learning rate
         self.step(0)
-        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
+        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
+
+    def get_wd(self):
+        """ Weight decay incr functions"""
+        if self.num_steps > self.wd_incr_steps:
+            return self.end_wd
+
+        if self.wd_incr_style == 'constant':
+            assert self.start_wd == self.end_wd
+            return self.end_wd
+
+        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
+        assert incr_ratio >= 0.0
+        assert incr_ratio <= 1.0
+        delta_wd = self.end_wd - self.start_wd
+
+        if self.wd_incr_style == 'linear':
+            coeff = incr_ratio
+        elif self.wd_incr_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
+        else:
+            raise Exception('{} weight decay increment style is not supported.'.format(
+                self.wd_incr_style))
+
+        return self.start_wd + coeff * delta_wd

     def get_lr(self):

@@ -60,33 +92,33 @@ class AnnealingLR(object):
         https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

         # Use linear warmup for the initial part.
-        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
             return self.max_lr * float(self.num_steps) / \
-                float(self.warmup_steps)
+                float(self.lr_warmup_steps)

         # If the learning rate is constant, just return the initial value.
-        if self.decay_style == 'constant':
+        if self.lr_decay_style == 'constant':
             return self.max_lr

-        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
-        if self.num_steps > self.decay_steps:
+        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.lr_decay_steps:
             return self.min_lr

         # If we are done with the warmup period, use the decay style.
-        num_steps_ = self.num_steps - self.warmup_steps
-        decay_steps_ = self.decay_steps - self.warmup_steps
+        num_steps_ = self.num_steps - self.lr_warmup_steps
+        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
         decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
         delta_lr = self.max_lr - self.min_lr

-        if self.decay_style == 'linear':
+        if self.lr_decay_style == 'linear':
             coeff = (1.0 - decay_ratio)
-        elif self.decay_style == 'cosine':
+        elif self.lr_decay_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         else:
             raise Exception('{} decay style is not supported.'.format(
-                self.decay_style))
+                self.lr_decay_style))

         return self.min_lr + coeff * delta_lr

@@ -95,18 +127,24 @@ class AnnealingLR(object):
         """Set lr for all parameters groups."""
         self.num_steps += increment
         new_lr = self.get_lr()
+        new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr
+            group['lr'] = new_lr * group.get('lr_mult', 1.0)
+            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)

     def state_dict(self):
         state_dict = {
             'max_lr': self.max_lr,
-            'warmup_steps': self.warmup_steps,
+            'lr_warmup_steps': self.lr_warmup_steps,
             'num_steps': self.num_steps,
-            'decay_style': self.decay_style,
-            'decay_steps': self.decay_steps,
-            'min_lr': self.min_lr
+            'lr_decay_style': self.lr_decay_style,
+            'lr_decay_steps': self.lr_decay_steps,
+            'min_lr': self.min_lr,
+            'start_wd': self.start_wd,
+            'end_wd': self.end_wd,
+            'wd_incr_style': self.wd_incr_style,
+            'wd_incr_steps': self.wd_incr_steps
         }
         return state_dict

@@ -114,13 +152,13 @@ class AnnealingLR(object):
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
-        if self.override_lr_scheduler:
+        if self.override_opt_param_scheduler:
             print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
             return cls_value

-        if not self.use_checkpoint_lr_scheduler:
+        if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
-                f'AnnealingLR: class input value {cls_value} and checkpoint' \
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
                 f'value {sd_value} for {name} do not match'
         print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                   name))

@@ -140,25 +178,57 @@ class AnnealingLR(object):
                                            'minimum learning rate')

         if 'warmup_iter' in sd:
-            warmup_steps_ = sd['warmup_iter']
+            lr_warmup_steps_ = sd['warmup_iter']
+        elif 'warmup_steps' in sd:
+            lr_warmup_steps_ = sd['warmup_steps']
         else:
-            warmup_steps_ = sd['warmup_steps']
-        self.warmup_steps = self._check_and_set(self.warmup_steps,
-                                                warmup_steps_,
-                                                'warmup iterations')
+            lr_warmup_steps_ = sd['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
+                                                   lr_warmup_steps_,
+                                                   'warmup iterations')

         if 'end_iter' in sd:
-            decay_steps_ = sd['end_iter']
+            lr_decay_steps_ = sd['end_iter']
+        elif 'decay_steps' in sd:
+            lr_decay_steps_ = sd['decay_steps']
         else:
-            decay_steps_ = sd['decay_steps']
-        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
-                                               'total number of iterations')
-        self.decay_style = self._check_and_set(self.decay_style,
-                                               sd['decay_style'],
-                                               'decay style')
+            lr_decay_steps_ = sd['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
+                                                  'total number of iterations')
+
+        if 'decay_style' in sd:
+            lr_decay_style_ = sd['decay_style']
+        else:
+            lr_decay_style_ = sd['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
+                                                  lr_decay_style_,
+                                                  'learning rate decay style')

         if 'num_iters' in sd:
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
         self.step(increment=num_steps)
+
+        if 'start_wd' in sd:
+            self.start_wd = self._check_and_set(self.start_wd,
+                                                sd['start_wd'],
+                                                "start weight decay")
+            self.end_wd = self._check_and_set(self.end_wd,
+                                              sd['end_wd'],
+                                              "end weight decay")
+            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
+                                                     sd['wd_incr_steps'],
+                                                     "total number of weight decay iterations")
+            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
+                                                     sd['wd_incr_style'],
+                                                     "weight decay incr style")
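A quick numeric check of the weight-decay increment defined in get_wd() above, extracted into a standalone function. For the cosine style the schedule is wd(t) = start_wd + 0.5*(cos(pi*(1 - t/T)) + 1)*(end_wd - start_wd), which rises from start_wd to end_wd over T steps.

    import math

    def wd_at(num_steps, wd_incr_steps, start_wd, end_wd, style='cosine'):
        # same logic as OptimizerParamScheduler.get_wd(), without the class state
        if num_steps > wd_incr_steps:
            return end_wd
        incr_ratio = num_steps / wd_incr_steps
        delta = end_wd - start_wd
        if style == 'linear':
            coeff = incr_ratio
        elif style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise ValueError(style)
        return start_wd + coeff * delta

    for t in (0, 250, 500, 750, 1000):
        print(t, round(wd_at(t, 1000, 0.01, 0.1), 4))
    # 0 -> 0.01, 500 -> 0.055, 1000 -> 0.1 (a half-cosine increase)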
megatron/schedules.py  (+53, −25)

@@ -98,7 +98,12 @@ def custom_backward(output, grad_output):
     )


-def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
+def forward_step(forward_step_func,
+                 data_iterator,
+                 model,
+                 input_tensor,
+                 forward_data_store,
+                 collect_non_loss_data=False):
     """Forward step for passed-in model.

     If first stage, input tensor is obtained from data_iterator, otherwise

@@ -120,10 +125,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+        if not collect_non_loss_data:
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
+
     timers('forward-compute').stop()

     # If T5 model (or other model with encoder and decoder)

@@ -206,8 +216,12 @@ def dummy_handler():
     pass


-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers, forward_only):
+def forward_backward_no_pipelining(forward_step_func,
+                                   data_iterator, model,
+                                   optimizer,
+                                   timers,
+                                   forward_only,
+                                   collect_non_loss_data=False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).

@@ -219,35 +233,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync

-    losses_reduced = []
+    forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                         input_tensor, losses_reduced)
+            output_tensor = forward_step(forward_step_func, data_iterator,
+                                         model, input_tensor, forward_data_store,
+                                         collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)

     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                 input_tensor, losses_reduced)
+    output_tensor = forward_step(forward_step_func, data_iterator,
+                                 model, input_tensor, forward_data_store,
+                                 collect_non_loss_data)
     if not forward_only:
         backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

-    return losses_reduced
+    return forward_data_store


-def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
-                                                  optimizer, timers, forward_only):
+def forward_backward_pipelining_with_interleaving(forward_step_func,
+                                                  data_iterator, model,
+                                                  optimizer,
+                                                  timers,
+                                                  forward_only,
+                                                  collect_non_loss_data=False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.

     Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
-    losses_reduced = []
+    forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]

@@ -307,7 +327,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         output_tensor = forward_step(forward_step_func,
                                      data_iterator[model_chunk_id],
                                      model[model_chunk_id],
-                                     input_tensor, losses_reduced)
+                                     input_tensor,
+                                     forward_data_store,
+                                     collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)

         # if forward-only, no need to save tensors for a backward pass

@@ -474,7 +496,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                         tensor_shape=tensor_shape,
                         timers=timers))

-    return losses_reduced
+    return forward_data_store


 def get_tensor_shapes(rank, model_type):

@@ -571,9 +593,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
     return input_tensors


-def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
-                                                     model, optimizer, timers,
-                                                     forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func,
+                                                     data_iterator,
+                                                     model,
+                                                     optimizer,
+                                                     timers,
+                                                     forward_only,
+                                                     collect_non_loss_data=False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.

@@ -608,13 +634,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     if not forward_only:
         input_tensors = []
         output_tensors = []
-    losses_reduced = []
+    forward_data_store = []

     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:

@@ -633,7 +660,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         last_iteration = (i == (num_microbatches_remaining - 1))

         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)

@@ -682,4 +710,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

-    return losses_reduced
+    return forward_data_store
megatron/training.py
View file @
e724785f
...
@@ -43,7 +43,7 @@ from megatron.model import ModelType
...
@@ -43,7 +43,7 @@ from megatron.model import ModelType
from
megatron.optimizer
import
get_megatron_optimizer
from
megatron.optimizer
import
get_megatron_optimizer
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
write_args_to_tensorboard
from
megatron.initialize
import
write_args_to_tensorboard
from
megatron.
learning_rates
import
AnnealingLR
from
megatron.
optimizer_param_scheduler
import
OptimizerParamScheduler
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
unwrap_model
from
megatron.utils
import
unwrap_model
...
@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider
,
model_provider
,
model_type
,
model_type
,
forward_step_func
,
forward_step_func
,
process_non_loss_data_func
=
None
,
extra_args_provider
=
None
,
extra_args_provider
=
None
,
args_defaults
=
{}):
args_defaults
=
{}):
"""Main training program.
"""Main training program.
...
@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
the info we would like to monitor during training, for example
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
`batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
args_defaults: a dictionary from argument-name to argument-value. It
...
@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate.
# Model, optimizer, and learning rate.
timers
(
'model-and-optimizer-setup'
).
start
()
timers
(
'model-and-optimizer-setup'
).
start
()
model
,
optimizer
,
lr
_scheduler
=
setup_model_and_optimizer
(
model_provider
,
model
,
optimizer
,
opt_param
_scheduler
=
setup_model_and_optimizer
(
model_provider
,
model_type
)
model_type
)
timers
(
'model-and-optimizer-setup'
).
stop
()
timers
(
'model-and-optimizer-setup'
).
stop
()
print_datetime
(
'after model, optimizer, and learning rate '
print_datetime
(
'after model, optimizer, and learning rate '
...
@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider,
...
@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider,
iteration
=
0
iteration
=
0
if
args
.
do_train
and
args
.
train_iters
>
0
:
if
args
.
do_train
and
args
.
train_iters
>
0
:
iteration
=
train
(
forward_step_func
,
iteration
=
train
(
forward_step_func
,
model
,
optimizer
,
lr_scheduler
,
model
,
optimizer
,
opt_param_scheduler
,
train_data_iterator
,
valid_data_iterator
)
train_data_iterator
,
valid_data_iterator
,
process_non_loss_data_func
)
print_datetime
(
'after training is done'
)
print_datetime
(
'after training is done'
)
if
args
.
do_valid
:
if
args
.
do_valid
:
prefix
=
'the end of training for val data'
prefix
=
'the end of training for val data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
valid_data_iterator
,
model
,
valid_data_iterator
,
model
,
iteration
,
False
)
iteration
,
process_non_loss_data_func
,
False
)
if
args
.
save
and
iteration
!=
0
:
if
args
.
save
and
iteration
!=
0
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
save_checkpoint
(
iteration
,
model
,
optimizer
,
opt_param
_scheduler
)
if
args
.
do_test
:
if
args
.
do_test
:
# Run on test data.
# Run on test data.
prefix
=
'the end of training for test data'
prefix
=
'the end of training for test data'
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
test_data_iterator
,
model
,
test_data_iterator
,
model
,
0
,
True
)
0
,
process_non_loss_data_func
,
True
)
def
update_train_iters
(
args
):
def
update_train_iters
(
args
):
...
@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
...
@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return
model
return
model
def
get_
learning_rate
_scheduler
(
optimizer
):
def
get_
optimizer_param
_scheduler
(
optimizer
):
"""Build the learning rate scheduler."""
"""Build the learning rate scheduler."""
args
=
get_args
()
args
=
get_args
()
...
@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer):
...
@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer):
if
args
.
train_iters
:
if
args
.
train_iters
:
if
args
.
lr_decay_iters
is
None
:
if
args
.
lr_decay_iters
is
None
:
args
.
lr_decay_iters
=
args
.
train_iters
args
.
lr_decay_iters
=
args
.
train_iters
decay_steps
=
args
.
lr_decay_iters
*
args
.
global_batch_size
lr_decay_steps
=
args
.
lr_decay_iters
*
args
.
global_batch_size
wd_incr_steps
=
args
.
train_iters
*
args
.
global_batch_size
if
args
.
lr_warmup_fraction
is
not
None
:
if
args
.
lr_warmup_fraction
is
not
None
:
warmup_steps
=
args
.
lr_warmup_fraction
*
decay_steps
lr_
warmup_steps
=
args
.
lr_warmup_fraction
*
lr_
decay_steps
else
:
else
:
warmup_steps
=
args
.
lr_warmup_iters
*
args
.
global_batch_size
lr_
warmup_steps
=
args
.
lr_warmup_iters
*
args
.
global_batch_size
# Sample-based training.
# Sample-based training.
elif
args
.
train_samples
:
elif
args
.
train_samples
:
# We need to set training iters for later use. Technically
# We need to set training iters for later use. Technically
...
@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer):
...
@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer):
update_train_iters
(
args
)
update_train_iters
(
args
)
if
args
.
lr_decay_samples
is
None
:
if
args
.
lr_decay_samples
is
None
:
args
.
lr_decay_samples
=
args
.
train_samples
args
.
lr_decay_samples
=
args
.
train_samples
decay_steps
=
args
.
lr_decay_samples
lr_decay_steps
=
args
.
lr_decay_samples
wd_incr_steps
=
args
.
train_samples
if
args
.
lr_warmup_fraction
is
not
None
:
if
args
.
lr_warmup_fraction
is
not
None
:
warmup_steps
=
args
.
lr_warmup_fraction
*
decay_steps
lr_
warmup_steps
=
args
.
lr_warmup_fraction
*
lr_
decay_steps
else
:
else
:
warmup_steps
=
args
.
lr_warmup_samples
lr_
warmup_steps
=
args
.
lr_warmup_samples
else
:
else
:
raise
Exception
(
raise
Exception
(
'either train-iters or train-samples should be provided.'
)
'either train-iters or train-samples should be provided.'
)
lr
_scheduler
=
AnnealingLR
(
opt_param
_scheduler
=
OptimizerParamScheduler
(
optimizer
,
optimizer
,
max_lr
=
args
.
lr
,
max_lr
=
args
.
lr
,
min_lr
=
args
.
min_lr
,
min_lr
=
args
.
min_lr
,
warmup_steps
=
warmup_steps
,
lr_warmup_steps
=
lr_warmup_steps
,
decay_steps
=
decay_steps
,
lr_decay_steps
=
lr_decay_steps
,
decay_style
=
args
.
lr_decay_style
,
lr_decay_style
=
args
.
lr_decay_style
,
use_checkpoint_lr_scheduler
=
args
.
use_checkpoint_lr_scheduler
,
start_wd
=
args
.
start_weight_decay
,
override_lr_scheduler
=
args
.
override_lr_scheduler
)
end_wd
=
args
.
end_weight_decay
,
wd_incr_steps
=
wd_incr_steps
,
return
lr_scheduler
wd_incr_style
=
args
.
weight_decay_incr_style
,
use_checkpoint_opt_param_scheduler
=
args
.
use_checkpoint_opt_param_scheduler
,
override_opt_param_scheduler
=
args
.
override_opt_param_scheduler
)
def
setup_model_and_optimizer
(
model_provider_func
,
model_type
):
return
opt_param_scheduler
def
setup_model_and_optimizer
(
model_provider_func
,
model_type
,
no_wd_decay_cond
=
None
,
scale_lr_cond
=
None
,
lr_mult
=
1.0
):
"""Setup model and optimizer."""
"""Setup model and optimizer."""
args
=
get_args
()
args
=
get_args
()
...
@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
...
@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model
=
unwrap_model
(
model
,
unwrapped_model
=
unwrap_model
(
model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
(
torchDDP
,
LocalDDP
,
Float16Module
))
optimizer
=
get_megatron_optimizer
(
unwrapped_model
)
optimizer
=
get_megatron_optimizer
(
unwrapped_model
,
no_wd_decay_cond
,
scale_lr_cond
,
lr_mult
)
lr
_scheduler
=
get_
learning_rate
_scheduler
(
optimizer
)
opt_param
_scheduler
=
get_
optimizer_param
_scheduler
(
optimizer
)
if
args
.
load
is
not
None
:
if
args
.
load
is
not
None
:
timers
=
get_timers
()
timers
=
get_timers
()
...
@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
...
@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
# max time.
# max time.
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
timers
(
'load-checkpoint'
).
start
()
timers
(
'load-checkpoint'
).
start
()
args
.
iteration
=
load_checkpoint
(
model
,
optimizer
,
lr
_scheduler
)
args
.
iteration
=
load_checkpoint
(
model
,
optimizer
,
opt_param
_scheduler
)
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
timers
(
'load-checkpoint'
).
stop
()
timers
(
'load-checkpoint'
).
stop
()
timers
.
log
([
'load-checkpoint'
])
timers
.
log
([
'load-checkpoint'
])
...
@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
...
@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
if
args
.
fp16
:
if
args
.
fp16
:
optimizer
.
reload_model_params
()
optimizer
.
reload_model_params
()
return
model
,
optimizer
,
lr
_scheduler
return
model
,
optimizer
,
opt_param
_scheduler
def
train_step
(
forward_step_func
,
data_iterator
,
def
train_step
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
lr
_scheduler
):
model
,
optimizer
,
opt_param
_scheduler
):
"""Single training step."""
"""Single training step."""
args
=
get_args
()
args
=
get_args
()
timers
=
get_timers
()
timers
=
get_timers
()
...
@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator,
...
@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator,
increment
=
get_num_microbatches
()
*
\
increment
=
get_num_microbatches
()
*
\
args
.
micro_batch_size
*
\
args
.
micro_batch_size
*
\
args
.
data_parallel_size
args
.
data_parallel_size
lr
_scheduler
.
step
(
increment
=
increment
)
opt_param
_scheduler
.
step
(
increment
=
increment
)
skipped_iter
=
0
skipped_iter
=
0
else
:
else
:
skipped_iter
=
1
skipped_iter
=
1
...
@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
...
@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return
report_memory_flag
return
report_memory_flag
def
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
lr
_scheduler
):
def
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
opt_param
_scheduler
):
timers
=
get_timers
()
timers
=
get_timers
()
# Extra barrier is added to make sure
# Extra barrier is added to make sure
# all ranks report the max time.
# all ranks report the max time.
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
timers
(
'save-checkpoint'
).
start
()
timers
(
'save-checkpoint'
).
start
()
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
save_checkpoint
(
iteration
,
model
,
optimizer
,
opt_param
_scheduler
)
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
timers
(
'save-checkpoint'
).
stop
()
timers
(
'save-checkpoint'
).
stop
()
timers
.
log
([
'save-checkpoint'
])
timers
.
log
([
'save-checkpoint'
])
def
train
(
forward_step_func
,
model
,
optimizer
,
lr_scheduler
,
def
train
(
forward_step_func
,
model
,
optimizer
,
opt_param_scheduler
,
train_data_iterator
,
valid_data_iterator
):
train_data_iterator
,
valid_data_iterator
,
process_non_loss_data_func
):
"""Train the model function."""
"""Train the model function."""
args
=
get_args
()
args
=
get_args
()
timers
=
get_timers
()
timers
=
get_timers
()
...
@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator
,
train_data_iterator
,
model
,
model
,
optimizer
,
optimizer
,
lr
_scheduler
)
opt_param
_scheduler
)
iteration
+=
1
iteration
+=
1
args
.
consumed_train_samples
+=
mpu
.
get_data_parallel_world_size
()
*
\
args
.
consumed_train_samples
+=
mpu
.
get_data_parallel_world_size
()
*
\
args
.
micro_batch_size
*
\
args
.
micro_batch_size
*
\
...
@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if
args
.
adlr_autoresume
and
\
if
args
.
adlr_autoresume
and
\
(
iteration
%
args
.
adlr_autoresume_interval
==
0
):
(
iteration
%
args
.
adlr_autoresume_interval
==
0
):
check_adlr_autoresume_termination
(
iteration
,
model
,
optimizer
,
check_adlr_autoresume_termination
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
opt_param
_scheduler
)
# Evaluation
# Evaluation
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
and
\
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
and
\
...
@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
prefix
=
'iteration {}'
.
format
(
iteration
)
prefix
=
'iteration {}'
.
format
(
iteration
)
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
valid_data_iterator
,
model
,
valid_data_iterator
,
model
,
iteration
,
False
)
iteration
,
process_non_loss_data_func
,
False
)
# Checkpointing
# Checkpointing
saved_checkpoint
=
False
saved_checkpoint
=
False
...
@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
signal_handler
=
get_signal_handler
()
signal_handler
=
get_signal_handler
()
if
any
(
signal_handler
.
signals_received
()):
if
any
(
signal_handler
.
signals_received
()):
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
opt_param
_scheduler
)
print_datetime
(
'exiting program after receiving SIGTERM.'
)
print_datetime
(
'exiting program after receiving SIGTERM.'
)
sys
.
exit
()
sys
.
exit
()
if
args
.
save
and
args
.
save_interval
and
\
if
args
.
save
and
args
.
save_interval
and
\
iteration
%
args
.
save_interval
==
0
:
iteration
%
args
.
save_interval
==
0
:
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
opt_param
_scheduler
)
saved_checkpoint
=
True
saved_checkpoint
=
True
# Exiting based on duration
# Exiting based on duration
...
@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if
done
:
if
done
:
if
not
saved_checkpoint
:
if
not
saved_checkpoint
:
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
save_checkpoint_and_time
(
iteration
,
model
,
optimizer
,
lr
_scheduler
)
opt_param
_scheduler
)
print_datetime
(
'exiting program after {} minutes'
.
format
(
train_time
))
print_datetime
(
'exiting program after {} minutes'
.
format
(
train_time
))
sys
.
exit
()
sys
.
exit
()
...
@@ -753,7 +774,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
-                                        lr_scheduler)
+                                        opt_param_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()
...
@@ -762,7 +783,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
    return iteration


-def evaluate(forward_step_func, data_iterator, model, verbose=False):
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             verbose=False):
    """Evaluation."""
    args = get_args()
...
@@ -799,6 +824,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
+
+    collected_non_loss_data = None
+    if process_non_loss_data_func is not None and is_last_rank():
+        collected_non_loss_data = forward_backward_func(
+            forward_step_func, data_iterator, model, optimizer=None,
+            timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()
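The collection pass above runs only when a post-processing callback is supplied and only on the last rank, so loss-only evaluation pays no extra cost. What each micro-batch contributes to collected_non_loss_data is decided by the task's forward/loss code; the plumbing presumably sits in the schedule code, which is not shown in this excerpt. A purely illustrative sketch of a loss function that could also hand back predictions for visualization; the non_loss_data keyword is an assumption for this sketch, not necessarily this commit's API:

import torch
import torch.nn.functional as F

# Illustrative only: return the usual (loss, reduced-losses dict) pair,
# or raw predictions when the schedule is collecting non-loss data.
def loss_func(labels, output_tensor, non_loss_data=False):
    if non_loss_data:
        # e.g. per-pixel class predictions kept for later visualization
        return output_tensor.argmax(dim=1)
    loss = F.cross_entropy(output_tensor, labels)
    return loss, {'loss': loss}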
...
@@ -806,16 +837,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

-    return total_loss_dict
+    return total_loss_dict, collected_non_loss_data


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
-                               iteration, verbose=False):
+                               iteration, process_non_loss_data_func,
+                               verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

-    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+    total_loss_dict, collected_non_loss_data = evaluate(
+        forward_step_func, data_iterator, model,
+        process_non_loss_data_func, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
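With these changes, evaluate() returns a pair and every caller of evaluate_and_print_results() gains one positional argument. Callers with nothing extra to surface can pass None; a minimal usage sketch, with illustrative variable names:

# Loss-only evaluation: no callback, so nothing is collected.
loss_dict, collected = evaluate(forward_step_func, valid_data_iterator,
                                model, None, verbose=False)
assert collected is None

# Printing helper with the new argument order.
evaluate_and_print_results('iteration {}'.format(iteration), forward_step_func,
                           valid_data_iterator, model, iteration,
                           None, verbose=False)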
...
@@ -834,6 +868,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
            writer.add_scalar('{} validation ppl vs samples'.format(key),
                              ppl, args.consumed_train_samples)

+    if process_non_loss_data_func is not None and writer and is_last_rank():
+        process_non_loss_data_func(collected_non_loss_data, iteration, writer)
+
    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
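The callback receives whatever the collection pass gathered, the current iteration, and the TensorBoard writer, and it only fires on the last rank when a writer exists. A hypothetical callback for a vision task that logs collected predictions as image grids; the tensor layout is an assumption, since the real structure depends entirely on the task code:

import torchvision.utils as vutils

# Hypothetical: assumes collected_non_loss_data is a list with one entry
# per micro-batch, each an [N, 1, H, W] float tensor of predicted masks.
def process_non_loss_data(collected_non_loss_data, iteration, writer):
    for idx, preds in enumerate(collected_non_loss_data[:4]):
        grid = vutils.make_grid(preds, normalize=True)
        writer.add_image('validation predictions/{}'.format(idx),
                         grid, global_step=iteration)

A task script would pass such a function down through training so that it eventually reaches evaluate_and_print_results as process_non_loss_data_func.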
...
megatron/utils.py
View file @ e724785f
...
@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):

def check_adlr_autoresume_termination(iteration, model,
-                                      optimizer, lr_scheduler):
+                                      optimizer, opt_param_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint
...
@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
...
tasks/finetune_utils.py
View file @ e724785f
...
@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
    return train_dataloader, valid_dataloader


-def _train(model, optimizer, lr_scheduler, forward_step,
+def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
...
@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            start_iteration = 0

            # Train for one step.
-            out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
+            out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1
...
@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
-                                                 optimizer, lr_scheduler)
+                                                 optimizer, opt_param_scheduler)

            # Checkpointing
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                saved_checkpoint = True

            # Evaluation
...
@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
            # Exiting based on iterations
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
...
@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
    timers('model and optimizer').stop()

    # If pretrained checkpoint is provided and we have not trained for
...
@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
    # Finetune the model.
    if args.epochs > 0:
-        _train(model, optimizer, lr_scheduler, forward_step,
+        _train(model, optimizer, opt_param_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    # Or just evaluate.
    else:
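Taken together, the finetuning path now threads the renamed scheduler from setup to the training loop. A condensed sketch assembled from the call sites above, using only names visible in this diff:

# Inside finetune(), after the dataloaders are built.
model, optimizer, opt_param_scheduler = \
    setup_model_and_optimizer(model_provider, model_type)

_train(model, optimizer, opt_param_scheduler, forward_step,
       train_dataloader, valid_dataloader, end_of_epoch_callback)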
...
tasks/vision/finetune_utils.py
View file @ e724785f
...
@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _train(
    model,
    optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
...
@@ -179,7 +179,7 @@ def _train(
            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1
...
@@ -206,7 +206,7 @@ def _train(
                iteration % args.adlr_autoresume_interval == 0
            ):
                check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                )

            # Checkpointing
...
@@ -215,7 +215,7 @@ def _train(
                and args.save_interval
                and iteration % args.save_interval == 0
            ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

            # Evaluation
            if args.eval_interval and iteration % args.eval_interval == 0:
...
@@ -231,7 +231,7 @@ def _train(
        # Checkpointing at the end of each epoch.
        if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
...
@@ -266,7 +266,7 @@ def finetune(
    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
...
@@ -300,7 +300,7 @@ def finetune(
        _train(
            model,
            optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
...