OpenDAS / Megatron-LM · Commits

Commit d8c85650, authored Jan 21, 2022 by Vijay Korthikanti

    second phase of vision code merge

Parent: 798b6a64
Showing 6 changed files with 262 additions and 85 deletions (+262, -85)
    megatron/arguments.py           +11   -0
    megatron/learning_rates.py      +35   -2
    megatron/model/transformer.py   +70  -27
    megatron/optimizer/__init__.py  +52  -20
    megatron/schedules.py           +53  -25
    megatron/training.py            +41  -11
megatron/arguments.py

@@ -246,6 +246,10 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.fp16 or args.bf16, \
         'residual connection in fp32 only supported when using fp16 or bf16.'
 
+    if args.weight_decay is not None:
+        args.start_wd = args.weight_decay
+        args.end_wd = args.weight_decay
+
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
     # Persistent fused layer norm.

@@ -395,6 +399,13 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
+    group.add_argument('--start-wd', type=float, default=0.01,
+                       help='Initial weight decay coefficient for L2 regularization.')
+    group.add_argument('--end-wd', type=float, default=0.01,
+                       help='End of run weight decay coefficient for L2 regularization.')
+    group.add_argument('--wd-incr-style', type=str, default='linear',
+                       choices=['constant', 'linear', 'cosine'],
+                       help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,
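The hunk above only adds three regularization flags and maps --weight-decay onto the new endpoints. As a quick reference, here is a minimal, self-contained argparse sketch that mirrors those flags; it is illustrative only and not the actual Megatron-LM parser, which wires them into parse_args()/_add_regularization_args().

# Illustrative standalone parser mirroring the new weight decay schedule flags.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='regularization')
group.add_argument('--weight-decay', type=float, default=0.01,
                   help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-wd', type=float, default=0.01,
                   help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-wd', type=float, default=0.01,
                   help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--wd-incr-style', type=str, default='linear',
                   choices=['constant', 'linear', 'cosine'],
                   help='Weight decay increment function.')

# e.g. ramp weight decay from 0.0 to 0.1 with a cosine increment
args = parser.parse_args(['--start-wd', '0.0', '--end-wd', '0.1',
                          '--wd-incr-style', 'cosine'])
print(args.start_wd, args.end_wd, args.wd_incr_style)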
megatron/learning_rates.py

@@ -24,6 +24,7 @@ class AnnealingLR(object):
     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
+                 start_wd, end_wd, wd_incr_style,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):

@@ -43,6 +44,13 @@ class AnnealingLR(object):
         self.decay_style = decay_style
 
+        self.start_wd = start_wd
+        self.end_wd = end_wd
+        assert self.start_wd >= 0.0
+        assert self.end_wd >= self.start_wd
+        self.wd_incr_style = wd_incr_style
+
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:

@@ -51,10 +59,33 @@ class AnnealingLR(object):
         # Set the learning rate
         self.step(0)
 
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
 
+    def get_wd(self):
+        if self.num_steps > self.decay_steps:
+            return self.end_wd
+
+        if self.wd_incr_style == 'constant':
+            assert self.start_wd == self.end_wd
+            return self.end_wd
+
+        decay_ratio = float(self.num_steps) / float(self.decay_steps)
+        assert decay_ratio >= 0.0
+        assert decay_ratio <= 1.0
+        delta_wd = self.end_wd - self.start_wd
+
+        if self.wd_incr_style == 'linear':
+            coeff = decay_ratio
+        elif self.wd_incr_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
+        else:
+            raise Exception('{} weight decay increment style is not supported.'.format(
+                self.wd_incr_style))
+
+        return self.start_wd + coeff * delta_wd
+
     def get_lr(self):
         """Learning rate decay functions from:
           https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

@@ -95,8 +126,10 @@ class AnnealingLR(object):
         """Set lr for all parameters groups."""
         self.num_steps += increment
         new_lr = self.get_lr()
+        new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr
+            group['lr'] = new_lr * group['lr_mult']
+            group['weight_decay'] = new_wd * group['wd_mult']
 
     def state_dict(self):
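Note that step() now expects every optimizer param group to carry 'lr_mult' and 'wd_mult' keys, which the new get_param_groups() in megatron/optimizer/__init__.py provides. The following is a standalone sketch of the get_wd() increment schedule above, useful for sanity-checking values outside of AnnealingLR; it repeats the diff's logic and is purely illustrative.

# Standalone sketch of the weight decay increment schedule added in get_wd().
import math

def scheduled_wd(num_steps, decay_steps, start_wd, end_wd, wd_incr_style='linear'):
    if num_steps > decay_steps:
        return end_wd
    if wd_incr_style == 'constant':
        assert start_wd == end_wd
        return end_wd
    decay_ratio = float(num_steps) / float(decay_steps)
    delta_wd = end_wd - start_wd
    if wd_incr_style == 'linear':
        coeff = decay_ratio
    elif wd_incr_style == 'cosine':
        # increases from 0 at step 0 to 1 at decay_steps
        coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
    else:
        raise Exception('{} weight decay increment style is not supported.'.format(wd_incr_style))
    return start_wd + coeff * delta_wd

# e.g. cosine increment from 0.0 to 0.1 over 1000 steps
for step in (0, 250, 500, 750, 1000):
    print(step, round(scheduled_wd(step, 1000, 0.0, 0.1, 'cosine'), 4))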
megatron/model/transformer.py

@@ -43,6 +43,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
     """
 
+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        if self.drop_prob == 0. or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        output = x.div(keep_prob) * random_tensor
+        return output
+
 class ParallelMLP(MegatronModule):
     """MLP.

@@ -407,12 +430,14 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()
 
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
+        self.drop_path_rate = drop_path_rate
 
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm

@@ -435,6 +460,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate)
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(

@@ -478,6 +504,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states
 
+        if self.drop_path_rate == 0.0:
         # jit scripting for a nn.module (with dropout) is not
         # trigerring the fusion kernel. For now, we use two
         # different nn.functional routines to account for varying

@@ -497,6 +524,11 @@ class ParallelTransformerLayer(MegatronModule):
                 attention_bias.expand_as(residual),
                 residual,
                 self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)
 
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

@@ -532,6 +564,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
 
+        if self.drop_path_rate == 0.0:
         # re-enable torch grad to enable fused optimization.
         with torch.enable_grad():
             output = bias_dropout_add_func(

@@ -539,6 +572,11 @@ class ParallelTransformerLayer(MegatronModule):
                 mlp_bias.expand_as(residual),
                 residual,
                 self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)
 
         return output

@@ -549,7 +587,8 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()

@@ -558,6 +597,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate
 
         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method

@@ -568,6 +608,8 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
 
+        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(

@@ -575,7 +617,8 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.dpr[layer_number-1])
 
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
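The new DropPath module applies stochastic depth: during training it zeroes the residual-branch output for a random subset of samples and rescales the rest by 1/keep_prob, and ParallelTransformer assigns each layer a linearly increasing rate via torch.linspace (self.dpr). The sketch below shows the same masking trick on a dummy tensor using plain torch.nn.Module so it runs outside Megatron-LM; the class name and the example tensor shape are illustrative assumptions, the masking logic mirrors the diff.

# Minimal sketch of the stochastic-depth masking used by DropPath above.
import torch

class DropPathSketch(torch.nn.Module):
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per index along dim 0, broadcast over remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()               # 0 or 1 per sample
        return x.div(keep_prob) * random_tensor  # rescale the kept samples

drop_path = DropPathSketch(drop_prob=0.2)
drop_path.train()
x = torch.ones(4, 3, 8)   # dummy activations
y = drop_path(x)          # roughly 20% of dim-0 slices zeroed, the rest scaled by 1/0.8
print(y[:, 0, 0])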
megatron/optimizer/__init__.py

@@ -23,35 +23,67 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
 
 
-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """creates param groups based on weight decay condition (regularized vs non regularized)
+       and learning rate scale condition (args.lr vs lr_mult * args.lr)
+       scale_lr_cond is used during finetuning where head of the network requires a scaled
+       version of the base learning rate.
     """
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
     for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None and p.requires_grad])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups
 
-def get_megatron_optimizer(model):
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()
 
     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)
 
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
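get_param_groups() replaces the old LayerNorm/bias split with caller-supplied predicates: each callable receives (name, param) and decides whether the parameter skips weight decay or gets a scaled learning rate. The sketch below shows the kind of callables a finetuning script might pass to get_megatron_optimizer(); the predicate bodies and the "head" name are hypothetical, only the signatures and the default bias/1-D rule come from the diff.

# Hypothetical condition callables matching the (name, param) signature above.
import torch

def no_weight_decay_cond(name, param):
    # mimic the default rule: no decay for biases and 1-D params (e.g. layernorm weights)
    return name.endswith(".bias") or len(param.shape) == 1

def scale_lr_cond(name, param):
    # scale the learning rate only for a (hypothetical) classification head
    return "head" in name

# quick check against a toy module
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
for name, param in model.named_parameters():
    print(name, no_weight_decay_cond(name, param), scale_lr_cond(name, param))

# In Megatron-LM these would be forwarded as:
# optimizer = get_megatron_optimizer(model, no_weight_decay_cond, scale_lr_cond, lr_mult=10.0)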
megatron/schedules.py

@@ -91,7 +91,12 @@ def custom_backward(output, grad_output):
         accumulate_grad=True,
     )
 
-def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
+def forward_step(forward_step_func,
+                 data_iterator,
+                 model,
+                 input_tensor,
+                 forward_data_store,
+                 collect_non_loss_data=False):
     """Forward step for passed-in model.
 
     If first stage, input tensor is obtained from data_iterator, otherwise

@@ -113,10 +118,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+        if not collect_non_loss_data:
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
+
     timers('forward-compute').stop()
 
     # If T5 model (or other model with encoder and decoder)

@@ -203,8 +213,12 @@ def dummy_handler():
     pass
 
-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers, forward_only):
+def forward_backward_no_pipelining(forward_step_func,
+                                   data_iterator, model,
+                                   optimizer,
+                                   timers,
+                                   forward_only,
+                                   collect_non_loss_data=False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).

@@ -216,35 +230,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync
 
-    losses_reduced = []
+    forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                         input_tensor, losses_reduced)
+            output_tensor = forward_step(forward_step_func, data_iterator,
+                                         model, input_tensor, forward_data_store,
+                                         collect_non_loss_data)
            if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                 input_tensor, losses_reduced)
+    output_tensor = forward_step(forward_step_func, data_iterator,
+                                 model, input_tensor, forward_data_store,
+                                 collect_non_loss_data)
     if not forward_only:
         backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
 
-    return losses_reduced
+    return forward_data_store
 
-def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
-                                                  optimizer, timers, forward_only):
+def forward_backward_pipelining_with_interleaving(forward_step_func,
+                                                  data_iterator, model,
+                                                  optimizer,
+                                                  timers,
+                                                  forward_only,
+                                                  collect_non_loss_data=False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
-    losses_reduced = []
+    forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]

@@ -304,7 +324,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         output_tensor = forward_step(forward_step_func,
                                      data_iterator[model_chunk_id],
                                      model[model_chunk_id],
-                                     input_tensor, losses_reduced)
+                                     input_tensor,
+                                     forward_data_store,
+                                     collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)
 
         # if forward-only, no need to save tensors for a backward pass

@@ -471,7 +493,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                                     tensor_shape=tensor_shape,
                                     timers=timers))
 
-    return losses_reduced
+    return forward_data_store
 
 def get_tensor_shapes(rank, model_type):

@@ -568,9 +590,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
     return input_tensors
 
-def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
-                                                     model, optimizer, timers,
-                                                     forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func,
+                                                     data_iterator,
+                                                     model,
+                                                     optimizer,
+                                                     timers,
+                                                     forward_only,
+                                                     collect_non_loss_data=False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.

@@ -605,13 +631,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     if not forward_only:
         input_tensors = []
         output_tensors = []
-    losses_reduced = []
+    forward_data_store = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)
 
         if not forward_only:

@@ -630,7 +657,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         last_iteration = (i == (num_microbatches_remaining - 1))
 
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)

@@ -679,4 +707,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
 
-    return losses_reduced
+    return forward_data_store
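All three schedules now gain a collect_non_loss_data flag and return a generic forward_data_store instead of losses_reduced. The implied contract is that, on the last pipeline stage, loss_func is called with non_loss_data=True and its return value is stored instead of the reduced loss. The sketch below illustrates a loss_func shaped to satisfy that contract; the body (cross-entropy plus argmax predictions) is a made-up example payload, only the non_loss_data keyword comes from the diff.

# Hedged sketch of the loss_func contract implied by forward_step() above.
import torch

def loss_func(labels, output_tensor, non_loss_data=False):
    if non_loss_data:
        # return arbitrary tensors for downstream processing (e.g. predictions
        # for logging); these are appended to forward_data_store instead of losses
        return output_tensor.argmax(dim=-1)

    loss = torch.nn.functional.cross_entropy(output_tensor, labels)
    averaged_loss = loss.detach()  # in Megatron this would be an all-reduced average
    return loss, {'lm loss': averaged_loss}

# A forward_step_func would return (output_tensor, partial(loss_func, labels));
# the schedule then calls loss_func(output_tensor) or loss_func(output_tensor, non_loss_data=True).
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
print(loss_func(labels, logits))
print(loss_func(labels, logits, non_loss_data=True))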
megatron/training.py

@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
              model_provider,
              model_type,
              forward_step_func,
+             process_non_loss_data_func=None,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.

@@ -145,14 +146,16 @@ def pretrain(train_valid_test_dataset_provider,
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
+                          train_data_iterator, valid_data_iterator,
+                          process_non_loss_data_func)
     print_datetime('after training is done')
 
     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    valid_data_iterator, model,
-                                   iteration, False)
+                                   iteration, process_non_loss_data_func,
+                                   False)
 
     if args.save and iteration != 0:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)

@@ -162,7 +165,8 @@ def pretrain(train_valid_test_dataset_provider,
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
-                                   0, True)
+                                   0, process_non_loss_data_func,
+                                   True)
 
 def update_train_iters(args):

@@ -333,13 +337,20 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
+        start_wd=args.start_wd,
+        end_wd=args.end_wd,
+        wd_incr_style=args.wd_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
 
     return lr_scheduler
 
-def setup_model_and_optimizer(model_provider_func, model_type):
+def setup_model_and_optimizer(model_provider_func,
+                              model_type,
+                              no_wd_decay_cond=None,
+                              scale_lr_cond=None,
+                              lr_mult=1.0):
     """Setup model and optimizer."""
     args = get_args()

@@ -347,7 +358,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model)
+    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -659,7 +671,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, valid_data_iterator):
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func):
     """Train the model function."""
     args = get_args()
     timers = get_timers()

@@ -716,7 +729,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
-                                       iteration, False)
+                                       iteration, process_non_loss_data_func,
+                                       False)
 
         # Checkpointing
         saved_checkpoint = False

@@ -762,7 +776,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     return iteration
 
-def evaluate(forward_step_func, data_iterator, model, verbose=False):
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             verbose=False):
     """Evaluation."""
     args = get_args()

@@ -799,6 +817,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()
 
+    collected_non_loss_data = None
+    if process_non_loss_data_func is not None and is_last_rank():
+        collected_non_loss_data = forward_backward_func(
+            forward_step_func, data_iterator, model, optimizer=None,
+            timers=None, forward_only=True, collect_non_loss_data=True)
+
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()

@@ -806,16 +830,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
 
-    return total_loss_dict
+    return total_loss_dict, collected_non_loss_data
 
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
-                               iteration, verbose=False):
+                               iteration, process_non_loss_data_func,
+                               verbose=False):
     """Helper function to evaluate and dump results on screen."""
     args = get_args()
     writer = get_tensorboard_writer()
 
-    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+    total_loss_dict, collected_non_loss_data = evaluate(
+        forward_step_func, data_iterator, model,
+        process_non_loss_data_func, verbose)
     string = ' validation loss at {} | '.format(prefix)
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())

@@ -834,6 +861,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar('{} validation ppl vs samples'.format(key),
                               ppl, args.consumed_train_samples)
 
+    if process_non_loss_data_func is not None and writer and is_last_rank():
+        process_non_loss_data_func(collected_non_loss_data, iteration, writer)
+
     length = len(string) + 1
     print_rank_last('-' * length)
     print_rank_last(string)
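pretrain() now threads an optional process_non_loss_data_func through train() and evaluation; on the last rank during evaluation it receives the collected non-loss data, the iteration, and the TensorBoard writer. The sketch below shows a callback of that shape; only the (collected_non_loss_data, iteration, writer) call signature comes from the diff, the histogram-logging body and the commented-out pretrain call are hypothetical examples.

# Hypothetical example of a process_non_loss_data_func callback.

def process_non_loss_data_func(collected_non_loss_data, iteration, writer):
    # collected_non_loss_data holds one entry per evaluation microbatch, containing
    # whatever loss_func returned when called with non_loss_data=True.
    for batch_idx, data in enumerate(collected_non_loss_data):
        # e.g. log a histogram of predictions to TensorBoard
        writer.add_histogram('eval/predictions/batch{}'.format(batch_idx),
                             data, iteration)

# Wiring it into a training script would look roughly like:
# pretrain(train_valid_test_dataset_provider, model_provider, model_type,
#          forward_step, process_non_loss_data_func=process_non_loss_data_func)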