OpenDAS / Megatron-LM · Commit d8c85650

second phase of vision code merge

Authored Jan 21, 2022 by Vijay Korthikanti
Parent commit: 798b6a64

Showing 6 changed files with 262 additions and 85 deletions.
  megatron/arguments.py           +11    -0
  megatron/learning_rates.py      +35    -2
  megatron/model/transformer.py   +70   -27
  megatron/optimizer/__init__.py  +52   -20
  megatron/schedules.py           +53   -25
  megatron/training.py            +41   -11
megatron/arguments.py  (+11 / -0)

@@ -246,6 +246,10 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.fp16 or args.bf16, \
         'residual connection in fp32 only supported when using fp16 or bf16.'
 
+    if args.weight_decay is not None:
+        args.start_wd = args.weight_decay
+        args.end_wd = args.weight_decay
+
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
     # Persistent fused layer norm.

@@ -395,6 +399,13 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
+    group.add_argument('--start-wd', type=float, default=0.01,
+                       help='Initial weight decay coefficient for L2 regularization.')
+    group.add_argument('--end-wd', type=float, default=0.01,
+                       help='End of run weight decay coefficient for L2 regularization.')
+    group.add_argument('--wd-incr-style', type=str, default='linear',
+                       choices=['constant', 'linear', 'cosine'],
+                       help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,
megatron/learning_rates.py  (+35 / -2)

@@ -24,6 +24,7 @@ class AnnealingLR(object):
     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
+                 start_wd, end_wd, wd_incr_style,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):

@@ -43,6 +44,13 @@ class AnnealingLR(object):
         self.decay_style = decay_style
 
+        self.start_wd = start_wd
+        self.end_wd = end_wd
+        assert self.start_wd >= 0.0
+        assert self.end_wd >= self.start_wd
+        self.wd_incr_style = wd_incr_style
+
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:

@@ -51,10 +59,33 @@ class AnnealingLR(object):
         # Set the learning rate
         self.step(0)
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
 
+    def get_wd(self):
+        if self.num_steps > self.decay_steps:
+            return self.end_wd
+
+        if self.wd_incr_style == 'constant':
+            assert self.start_wd == self.end_wd
+            return self.end_wd
+
+        decay_ratio = float(self.num_steps) / float(self.decay_steps)
+        assert decay_ratio >= 0.0
+        assert decay_ratio <= 1.0
+        delta_wd = self.end_wd - self.start_wd
+
+        if self.wd_incr_style == 'linear':
+            coeff = decay_ratio
+        elif self.wd_incr_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
+        else:
+            raise Exception('{} weight decay increment style is not supported.'.format(
+                self.wd_incr_style))
+
+        return self.start_wd + coeff * delta_wd
+
     def get_lr(self):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

@@ -95,8 +126,10 @@ class AnnealingLR(object):
         """Set lr for all parameters groups."""
         self.num_steps += increment
         new_lr = self.get_lr()
+        new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr
+            group['lr'] = new_lr * group['lr_mult']
+            group['weight_decay'] = new_wd * group['wd_mult']
 
     def state_dict(self):
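For reference, the increment schedule implemented by get_wd() can be exercised outside of Megatron. The standalone sketch below mirrors the 'linear' and 'cosine' styles added above; the function name wd_schedule is illustrative and not part of this commit.

    import math

    def wd_schedule(num_steps, decay_steps, start_wd, end_wd, wd_incr_style='linear'):
        # Mirrors AnnealingLR.get_wd(): ramp weight decay from start_wd to
        # end_wd over decay_steps, then hold at end_wd.
        if num_steps > decay_steps:
            return end_wd
        if wd_incr_style == 'constant':
            assert start_wd == end_wd
            return end_wd
        incr_ratio = float(num_steps) / float(decay_steps)
        delta_wd = end_wd - start_wd
        if wd_incr_style == 'linear':
            coeff = incr_ratio
        elif wd_incr_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise ValueError('unsupported wd increment style: {}'.format(wd_incr_style))
        return start_wd + coeff * delta_wd

    # A cosine ramp from 0.0 to 0.1 over 10000 steps reaches the halfway value
    # at the midpoint:  wd_schedule(5000, 10000, 0.0, 0.1, 'cosine')  ->  ~0.05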
megatron/model/transformer.py  (+70 / -27)

@@ -43,6 +43,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """
 
+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        if self.drop_prob == 0. or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        output = x.div(keep_prob) * random_tensor
+        return output
+
+
 class ParallelMLP(MegatronModule):
     """MLP.

@@ -407,12 +430,14 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()
 
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
+        self.drop_path_rate = drop_path_rate
 
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

@@ -435,6 +460,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate)
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(

@@ -478,25 +504,31 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states
 
-        # jit scripting for a nn.module (with dropout) is not
-        # trigerring the fusion kernel. For now, we use two
-        # different nn.functional routines to account for varying
-        # dropout semantics during training and inference phases.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
-            else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path_rate == 0.0:
+            # jit scripting for a nn.module (with dropout) is not
+            # trigerring the fusion kernel. For now, we use two
+            # different nn.functional routines to account for varying
+            # dropout semantics during training and inference phases.
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
+            else:
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)
 
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

@@ -532,13 +564,19 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
 
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            output = bias_dropout_add_func(
-                mlp_output,
-                mlp_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path_rate == 0.0:
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                output = bias_dropout_add_func(
+                    mlp_output,
+                    mlp_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)
 
         return output

@@ -549,7 +587,8 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()

@@ -558,6 +597,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate
 
         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method

@@ -568,6 +608,8 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
 
+        self.dpr = [x.item() for x in
+                    torch.linspace(0, self.drop_path_rate, self.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(

@@ -575,7 +617,8 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.dpr[layer_number-1])
 
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
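The DropPath module added above implements per-sample stochastic depth. A self-contained sketch of the same math, using torch.nn.Module instead of MegatronModule so it runs without the Megatron runtime (the class name DropPathSketch is illustrative):

    import torch

    class DropPathSketch(torch.nn.Module):
        """Zero a sample's entire residual branch with probability drop_prob,
        scaling the surviving samples by 1 / keep_prob."""

        def __init__(self, drop_prob=0.1):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0. or not self.training:
                return x
            keep_prob = 1 - self.drop_prob
            # one Bernoulli draw per sample, broadcast over the remaining dims
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = (keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)).floor_()
            return x.div(keep_prob) * mask

    drop = DropPathSketch(drop_prob=0.2)
    drop.train()
    y = drop(torch.ones(4, 3, 8))   # roughly 20% of samples zeroed, rest scaled by 1/0.8
    drop.eval()
    x = torch.ones(4, 3, 8)
    assert torch.equal(drop(x), x)  # identity at inference time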
megatron/optimizer/__init__.py  (+52 / -20)

@@ -23,35 +23,67 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
 
-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
-    """
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None and p.requires_grad])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
-
-def get_megatron_optimizer(model):
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """creates param groups based on weight decay condition (regularized vs non regularized)
+       and learning rate scale condition (args.lr vs lr_mult * args.lr)
+       scale_lr_cond is used during finetuning where head of the network requires a scaled
+       version of the base learning rate.
+    """
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups
+
+
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()
 
     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)
+
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
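The two condition callables are meant to be supplied by downstream training scripts (for example vision fine-tuning). A hedged sketch of what they might look like; the predicate bodies and the 'head' substring are assumptions for illustration, only the (name, param) signature comes from the diff above:

    def no_weight_decay_cond(name, param):
        # skip weight decay for biases and 1-D (norm) parameters
        return name.endswith('.bias') or len(param.shape) == 1

    def scale_lr_cond(name, param):
        # scale the learning rate only for a (hypothetical) classification head
        return 'head' in name

    # optimizer = get_megatron_optimizer(model,
    #                                    no_weight_decay_cond=no_weight_decay_cond,
    #                                    scale_lr_cond=scale_lr_cond,
    #                                    lr_mult=10.0)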
megatron/schedules.py  (+53 / -25)

@@ -91,7 +91,12 @@ def custom_backward(output, grad_output):
         accumulate_grad=True,)
 
-def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
+def forward_step(forward_step_func,
+                 data_iterator,
+                 model,
+                 input_tensor,
+                 forward_data_store,
+                 collect_non_loss_data=False):
     """Forward step for passed-in model.
 
     If first stage, input tensor is obtained from data_iterator, otherwise

@@ -113,10 +118,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, ...):
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+        if not collect_non_loss_data:
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
     timers('forward-compute').stop()
 
     # If T5 model (or other model with encoder and decoder)

@@ -203,8 +213,12 @@ def dummy_handler():
     pass
 
-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers, forward_only):
+def forward_backward_no_pipelining(forward_step_func,
+                                   data_iterator, model,
+                                   optimizer,
+                                   timers,
+                                   forward_only,
+                                   collect_non_loss_data=False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).

@@ -216,35 +230,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync
 
-    losses_reduced = []
+    forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                         input_tensor, losses_reduced)
+            output_tensor = forward_step(forward_step_func, data_iterator,
+                                         model, input_tensor, forward_data_store,
+                                         collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                 input_tensor, losses_reduced)
+    output_tensor = forward_step(forward_step_func, data_iterator,
+                                 model, input_tensor, forward_data_store,
+                                 collect_non_loss_data)
     if not forward_only:
         backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
 
-    return losses_reduced
+    return forward_data_store
 
 
-def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
-                                                  optimizer, timers, forward_only):
+def forward_backward_pipelining_with_interleaving(forward_step_func,
+                                                  data_iterator, model,
+                                                  optimizer,
+                                                  timers,
+                                                  forward_only,
+                                                  collect_non_loss_data=False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
-    losses_reduced = []
+    forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]

@@ -304,7 +324,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...):
             output_tensor = forward_step(forward_step_func,
                                          data_iterator[model_chunk_id],
                                          model[model_chunk_id],
-                                         input_tensor, losses_reduced)
+                                         input_tensor,
+                                         forward_data_store,
+                                         collect_non_loss_data)
             output_tensors[model_chunk_id].append(output_tensor)
 
             # if forward-only, no need to save tensors for a backward pass

@@ -471,7 +493,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, ...):
                 tensor_shape=tensor_shape,
                 timers=timers))
 
-    return losses_reduced
+    return forward_data_store
 
 
 def get_tensor_shapes(rank, model_type):

@@ -568,9 +590,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
     return input_tensors
 
 
-def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
-                                                     model, optimizer, timers,
-                                                     forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func,
+                                                     data_iterator,
+                                                     model,
+                                                     optimizer,
+                                                     timers,
+                                                     forward_only,
+                                                     collect_non_loss_data=False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.

@@ -605,13 +631,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...):
     if not forward_only:
         input_tensors = []
         output_tensors = []
-    losses_reduced = []
+    forward_data_store = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:

@@ -630,7 +657,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...):
         last_iteration = (i == (num_microbatches_remaining - 1))
 
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)

@@ -679,4 +707,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, ...):
         send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
 
-    return losses_reduced
+    return forward_data_store
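When collect_non_loss_data=True, forward_step() calls the task's loss_func with non_loss_data=True and stores whatever it returns in forward_data_store instead of a reduced loss. A minimal sketch of a compatible loss_func, assuming a classification-style task; the function name, the argmax payload, and the functools.partial binding are illustrative assumptions:

    import functools
    import torch

    def _example_loss_func(labels, output_tensor, non_loss_data=False):
        # forward_step() calls this as loss_func(output_tensor) on the loss path,
        # or loss_func(output_tensor, non_loss_data=True) to collect extra tensors.
        logits = output_tensor
        if non_loss_data:
            # return whatever should be gathered for later visualization/metrics
            return torch.argmax(logits, dim=-1)
        loss = torch.nn.functional.cross_entropy(logits.float(), labels)
        return loss, {'loss': loss.detach().clone()}

    # inside a forward_step_func the labels are typically bound first:
    #     return logits, functools.partial(_example_loss_func, labels)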
megatron/training.py  (+41 / -11)

@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
              model_provider,
              model_type,
              forward_step_func,
+             process_non_loss_data_func=None,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.

@@ -145,14 +146,16 @@ def pretrain(train_valid_test_dataset_provider,
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
+                          train_data_iterator, valid_data_iterator,
+                          process_non_loss_data_func)
     print_datetime('after training is done')
 
     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    valid_data_iterator, model,
-                                   iteration, False)
+                                   iteration, process_non_loss_data_func,
+                                   False)
 
     if args.save and iteration != 0:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)

@@ -162,7 +165,8 @@ def pretrain(train_valid_test_dataset_provider,
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
-                                   0, True)
+                                   0, process_non_loss_data_func,
+                                   True)
 
 
 def update_train_iters(args):

@@ -333,13 +337,20 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
+        start_wd=args.start_wd,
+        end_wd=args.end_wd,
+        wd_incr_style=args.wd_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
 
     return lr_scheduler
 
 
-def setup_model_and_optimizer(model_provider_func, model_type):
+def setup_model_and_optimizer(model_provider_func,
+                              model_type,
+                              no_wd_decay_cond=None,
+                              scale_lr_cond=None,
+                              lr_mult=1.0):
     """Setup model and optimizer."""
     args = get_args()

@@ -347,7 +358,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model)
+    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -659,7 +671,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, valid_data_iterator):
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func):
     """Train the model function."""
     args = get_args()
     timers = get_timers()

@@ -716,7 +729,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
-                                       iteration, False)
+                                       iteration, process_non_loss_data_func,
+                                       False)
 
         # Checkpointing
         saved_checkpoint = False

@@ -762,7 +776,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     return iteration
 
 
-def evaluate(forward_step_func, data_iterator, model, verbose=False):
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             verbose=False):
     """Evaluation."""
     args = get_args()

@@ -799,6 +817,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()
 
+        collected_non_loss_data = None
+        if process_non_loss_data_func is not None and is_last_rank():
+            collected_non_loss_data = forward_backward_func(
+                forward_step_func, data_iterator, model, optimizer=None,
+                timers=None, forward_only=True, collect_non_loss_data=True)
+
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()

@@ -806,16 +830,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
 
-    return total_loss_dict
+    return total_loss_dict, collected_non_loss_data
 
 
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
-                               iteration, verbose=False):
+                               iteration, process_non_loss_data_func,
+                               verbose=False):
     """Helper function to evaluate and dump results on screen."""
     args = get_args()
     writer = get_tensorboard_writer()
 
-    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+    total_loss_dict, collected_non_loss_data = evaluate(
+        forward_step_func, data_iterator, model,
+        process_non_loss_data_func, verbose)
     string = ' validation loss at {} | '.format(prefix)
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())

@@ -834,6 +861,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar('{} validation ppl vs samples'.format(key),
                               ppl, args.consumed_train_samples)
 
+    if process_non_loss_data_func is not None and writer and is_last_rank():
+        process_non_loss_data_func(collected_non_loss_data, iteration, writer)
+
     length = len(string) + 1
     print_rank_last('-' * length)
     print_rank_last(string)
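pretrain() now threads an optional process_non_loss_data_func down to evaluation; on the last rank it receives the collected per-microbatch outputs, the current iteration, and the TensorBoard writer. A hedged sketch of such a callback (the histogram/scalar content is an assumption for illustration):

    import torch

    def process_non_loss_data_func(collected_data, iteration, writer):
        # collected_data is the forward_data_store returned by the forward-backward
        # function: one entry per evaluation microbatch (whatever loss_func returned
        # with non_loss_data=True).
        preds = torch.cat([batch.flatten() for batch in collected_data])
        writer.add_scalar('eval/num_collected_predictions', preds.numel(), iteration)
        writer.add_histogram('eval/predicted_classes', preds.float(), iteration)

    # pretrain(train_valid_test_dataset_provider, model_provider, model_type,
    #          forward_step_func, process_non_loss_data_func=process_non_loss_data_func)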