OpenDAS / Megatron-LM / Commits

Commit 27e14f82, authored Mar 28, 2020 by Mohammad
refactored training
parent 3f58649b

Showing 4 changed files with 139 additions and 115 deletions (+139, -115)
megatron/arguments.py   +2    -0
megatron/training.py    +101  -90
pretrain_bert.py        +18   -14
pretrain_gpt2.py        +18   -11
megatron/arguments.py

@@ -234,6 +234,8 @@ def _add_mixed_precision_args(parser):
                             'attention-softmax-in-fp32 to true')
     group.add_argument('--attention-softmax-in-fp32', action='store_true',
                        help='Run attention masking and softmax in fp32.')
+    group.add_argument('--fp32-allreduce', action='store_true',
+                       help='All-reduce in fp32')
     group.add_argument('--hysteresis', type=int, default=2,
                        help='hysteresis for dynamic loss scaling')
     group.add_argument('--loss-scale', type=float, default=None,
...
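The only functional change in this file is the new --fp32-allreduce switch registered next to the existing mixed-precision flags. As a standalone sketch of how these store_true options behave (plain argparse, outside Megatron's argument plumbing; the group title here is assumed):

    import argparse

    # Recreate just the options visible in the hunk above.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('mixed precision')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32.')
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='All-reduce in fp32')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--loss-scale', type=float, default=None)

    args = parser.parse_args(['--fp32-allreduce'])
    assert args.fp32_allreduce is True         # store_true flag, off by default
    assert args.attention_softmax_in_fp32 is False
    assert args.hysteresis == 2                # untouched options keep their defaults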
megatron/training.py

@@ -13,62 +13,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Pretrain utilities"""
+"""Pretrain utilities."""
 
 from datetime import datetime
 import math
+import sys
 
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
 
-from megatron.global_vars import get_args
-from megatron.global_vars import get_timers
-from megatron.global_vars import get_tensorboard_writer
-from megatron.global_vars import get_adlr_autoresume
-from megatron.initialize import initialize_megatron
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
 from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
+from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.checkpointing import load_checkpoint
-from megatron import print_rank_0
 from megatron.utils import report_memory
-from megatron.checkpointing import save_checkpoint
 
 
-def run(top_level_message, train_val_test_data_provider,
-        model_provider, forward_step_func,
-        extra_args_provider=None, args_defaults={}):
+def pretrain(train_val_test_data_provider,
+             model_provider, forward_step_func,
+             extra_args_provider=None, args_defaults={}):
     """Main training program.
 
     This function will run the followings in the order provided:
-        1) get input arguments.
-        2) initialize distributed and seeds.
+        1) initialize Megatron.
+        2) setup model, optimizer and lr schedule using the model_provider.
         3) call train_val_test_data_provider to get train/val/test datasets.
-        4) setup model, optimizer and lr schedule using the model_provider.
-        5) train the modle using the forward_step_func.
+        4) train the modle using the forward_step_func.
 
     Arguments:
-        top_level_message: a meesage to print at the top of the run.
-        train_val_test_data_provider: a function that takes `args` as input
-            and returns `train, val, test` dataloaders. Note that args are
-            passed and can be modified in case we need to use some parameters
-            later. For example, we can set vocab size using
-                args.vocab_size = ...
-            and later use this value in `model_provider`.
-        model_provider: a function that takes `args` and returns a vanilla
-            version of the model. By vanilla we mean a simple model on cpu
-            with no fp16 or ddp.
-        forward_step_func: a function that takes a `data iterator`, `model`,
-            `args`, and `timers` and returns a `loss` scalar with a dictionary
-            with key:values being the info we would like to monitor during
-            training, for example `lm-loss: value`. We also require that this
-            function add `batch generator` to the timers class.
+        train_val_test_data_provider: a function that builds datasets
+            and returns `train, val, test` dataloaders.
+        model_provider: a function that returns a vanilla version of the
+            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
+        forward_step_func: a function that takes a `data iterator` and `model`,
+            and returns a `loss` scalar with a dictionary with key:values being
+            the info we would like to monitor during training, for example
+            `lm-loss: value`. We also require that this function add
+            `batch generator` to the timers class.
+        extra_args_provider: a function that takes a parser and adds arguments
+            to it. It is used for programs to add their own arguments.
+        args_defaults: a dictionary from argument-name to argument-value. It
+            to set already parse arguments.
     """
 
     # Initalize and get arguments, timers, and Tensorboard writer.
...
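The renamed entry point drops top_level_message and all explicit args/timers/writer plumbing: callers now hand pretrain() just the three provider callbacks (plus the optional argument hooks documented above), exactly as the updated pretrain_bert.py and pretrain_gpt2.py further down do. A minimal skeleton of such a caller under the new interface (the provider bodies below are placeholders, not Megatron code):

    from megatron import get_args
    from megatron.training import pretrain


    def train_val_test_data_provider():
        """Placeholder: build and return train/val/test dataloaders."""
        args = get_args()           # globals replace the old `args` parameter
        train, val, test = None, None, None
        return train, val, test


    def model_provider():
        """Placeholder: return a plain CPU model; fp16/DDP wrapping happens later."""
        args = get_args()
        model = ...
        return model


    def forward_step(data_iterator, model):
        """Placeholder: one forward pass -> (loss, {name: value to monitor})."""
        loss = ...
        return loss, {'lm loss': loss}


    if __name__ == "__main__":
        pretrain(train_val_test_data_provider, model_provider, forward_step,
                 args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})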
@@ -76,36 +71,44 @@ def run(top_level_message, train_val_test_data_provider,
                         args_defaults=args_defaults)
     args = get_args()
     timers = get_timers()
-    writer = get_tensorboard_writer()
-
-    # Data stuff.
-    train_data, val_data, test_data = train_val_test_data_provider(args)
 
     # Model, optimizer, and learning rate.
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
-                                                               args)
+    timers('model and optimizer').start()
+    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    timers('model and optimizer').stop()
+
+    # Data stuff.
+    timers('train/valid/test dataset').start()
+    train_data, val_data, test_data = train_val_test_data_provider()
+    timers('train/valid/test dataset').stop()
 
     # Train, validation, and test data.
+    timers('train/valid/test dataloader').start()
     train_data_iterator, val_data_iterator, \
         test_data_iterator = get_train_val_test_data_iterators(train_data,
                                                                 val_data,
-                                                                test_data,
-                                                                args)
+                                                                test_data)
+    timers('train/valid/test dataloader').stop()
+
+    # Print setup timing.
+    print_rank_0('done with setups ...')
+    timers.log(['model and optimizer', 'train/valid/test dataset',
+                'train/valid/test dataloader'])
+    print_rank_0('training ...')
 
     iteration = 0
     if args.train_iters > 0:
         if args.do_train:
-            iteration, _ = train(forward_step_func, model,
-                                 optimizer, lr_scheduler,
-                                 train_data_iterator, val_data_iterator,
-                                 timers, args, writer)
+            iteration, _ = train(forward_step_func,
+                                 model, optimizer, lr_scheduler,
+                                 train_data_iterator, val_data_iterator)
 
         if args.do_valid:
             prefix = 'the end of training for val data'
             evaluate_and_print_results(prefix, forward_step_func,
                                        val_data_iterator, model,
-                                       args, writer, iteration,
-                                       timers, False)
+                                       iteration, False)
 
     if args.save and iteration != 0:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
...
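Setup is now bracketed by named timers and reported with timers.log(...). The real Timers class lives in the megatron package (megatron/global_vars.py) and is not part of this diff; the snippet below is only an assumed, simplified stand-in to show the call pattern used above:

    import time


    class SimpleTimers:
        """Assumed, minimal stand-in for Megatron's timers object (illustration only)."""

        class _Timer:
            def __init__(self):
                self.elapsed = 0.0
                self._start = None

            def start(self):
                self._start = time.time()

            def stop(self):
                self.elapsed += time.time() - self._start
                self._start = None

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            # timers('model and optimizer') returns the named timer, creating it on demand.
            if name not in self._timers:
                self._timers[name] = self._Timer()
            return self._timers[name]

        def log(self, names):
            for name in names:
                print('{}: {:.2f} s'.format(name, self._timers[name].elapsed))


    timers = SimpleTimers()
    timers('model and optimizer').start()
    time.sleep(0.01)                  # stands in for building the model and optimizer
    timers('model and optimizer').stop()
    timers.log(['model and optimizer'])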
@@ -115,14 +118,15 @@ def run(top_level_message, train_val_test_data_provider,
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
-                                   args, None, 0, timers, True)
+                                   0, True)
 
 
-def get_model(model_provider_func, args):
+def get_model(model_provider_func):
     """Build the model."""
+    args = get_args()
 
     # Build model on cpu.
-    model = model_provider_func(args)
+    model = model_provider_func()
 
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
...

@@ -140,26 +144,24 @@ def get_model(model_provider_func, args):
     # Wrap model for distributed training."""
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
-        args.DDP_type = torchDDP
-        model = args.DDP_type(model, device_ids=[i], output_device=i,
-                              process_group=mpu.get_data_parallel_group())
+        model = torchDDP(model, device_ids=[i], output_device=i,
+                         process_group=mpu.get_data_parallel_group())
         return model
     if args.DDP_impl == 'local':
-        args.DDP_type = LocalDDP
-        model = args.DDP_type(model)
+        model = LocalDDP(model)
         return model
 
     print_rank_0('Unknown DDP implementation specified: {}. '
                  'Exiting.'.format(args.DDP_impl))
-    exit()
-    return model
+    sys.exit()
 
 
-def get_optimizer(model, args):
+def get_optimizer(model):
     """Set up the optimizer."""
+    args = get_args()
 
     # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (args.DDP_type, FP16_Module)):
+    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
         model = model.module
     param_groups = get_params_for_weight_decay_optimization(model)
...
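get_model() now wraps the model directly in the chosen DDP class instead of stashing the class in args.DDP_type. A self-contained sketch of the 'torch' branch, scaled down to a one-process gloo group so it runs on a single CPU machine (Megatron additionally passes device_ids, output_device, and process_group=mpu.get_data_parallel_group() so gradients are averaged only across data-parallel replicas):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

    # One-process process group so the sketch is runnable without a launcher.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    # Wrap a toy model the way the refactored get_model() wraps the real one.
    model = torchDDP(torch.nn.Linear(8, 8))

    out = model(torch.randn(2, 8))
    out.sum().backward()              # DDP all-reduces gradients across the group here
    dist.destroy_process_group()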
@@ -170,8 +172,7 @@ def get_optimizer(model, args):
             param.model_parallel = False
 
     # Use Adam.
     optimizer = Adam(param_groups,
-                     lr=args.lr,
-                     weight_decay=args.weight_decay)
+                     lr=args.lr, weight_decay=args.weight_decay)
 
     # Wrap into fp16 optimizer.
     if args.fp16:
...

@@ -186,8 +187,9 @@ def get_optimizer(model, args):
     return optimizer
 
 
-def get_learning_rate_scheduler(optimizer, args):
+def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
+    args = get_args()
 
     # Add linear learning rate scheduler.
     if args.lr_decay_iters is not None:
...

@@ -211,12 +213,13 @@ def get_learning_rate_scheduler(optimizer, args):
     return lr_scheduler
 
 
-def setup_model_and_optimizer(model_provider_func, args):
+def setup_model_and_optimizer(model_provider_func):
     """Setup model and optimizer."""
+    args = get_args()
 
-    model = get_model(model_provider_func, args)
-    optimizer = get_optimizer(model, args)
-    lr_scheduler = get_learning_rate_scheduler(optimizer, args)
+    model = get_model(model_provider_func)
+    optimizer = get_optimizer(model)
+    lr_scheduler = get_learning_rate_scheduler(optimizer)
 
     if args.load is not None:
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
...

@@ -226,8 +229,10 @@ def setup_model_and_optimizer(model_provider_func, args):
     return model, optimizer, lr_scheduler
 
 
-def backward_step(optimizer, model, loss, args, timers):
+def backward_step(optimizer, model, loss):
     """Backward step."""
+    args = get_args()
+    timers = get_timers()
 
     # Backward pass.
     optimizer.zero_grad()
...
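The pattern repeated throughout this file is that helpers drop their args/timers/writer parameters and instead call the module-level getters re-exported from the megatron package (get_args(), get_timers(), get_tensorboard_writer()). Those live in megatron/global_vars.py and are not shown in this diff; the following is only a simplified illustration of the underlying singleton-accessor idea, not Megatron's implementation:

    # Simplified illustration of module-level "global state + getter" accessors.
    _GLOBAL_ARGS = None


    def set_args(args):
        global _GLOBAL_ARGS
        _GLOBAL_ARGS = args


    def get_args():
        assert _GLOBAL_ARGS is not None, 'args are not initialized'
        return _GLOBAL_ARGS


    # Initialization parses arguments once and stores them ...
    class _Args:
        lr = 1.5e-4
        weight_decay = 0.01


    set_args(_Args())


    # ... and any helper can fetch them later without threading `args`
    # through every call site:
    def get_optimizer_settings():
        args = get_args()
        return args.lr, args.weight_decay


    print(get_optimizer_settings())   # (0.00015, 0.01)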
@@ -255,18 +260,20 @@ def backward_step(optimizer, model, loss, args, timers):
         optimizer.clip_master_grads(args.clip_grad)
 
 
-def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
-               args, timers):
+def train_step(forward_step_func, data_iterator,
+               model, optimizer, lr_scheduler):
     """Single training step."""
+    args = get_args()
+    timers = get_timers()
 
     # Forward model for one step.
     timers('forward').start()
-    loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
+    loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
 
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
-    backward_step(optimizer, model, loss, args, timers)
+    backward_step(optimizer, model, loss)
     timers('backward').stop()
 
     # Update parameters.
...
@@ -285,7 +292,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, writer, args, timers):
+                 loss_scale, report_memory_flag):
     """Log training information such as losses, timing, ...."""
+    args = get_args()
+    timers = get_timers()
+    writer = get_tensorboard_writer()
 
     # Update losses.
     for key in loss_dict:
...
@@ -341,8 +352,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args, writer):
+          train_data_iterator, val_data_iterator):
     """Train the model function."""
+    args = get_args()
+    timers = get_timers()
 
     # Turn on training mode which enables dropout.
     model.train()
...

@@ -361,8 +374,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              train_data_iterator,
                                              model,
                                              optimizer,
-                                             lr_scheduler,
-                                             args, timers)
+                                             lr_scheduler)
         skipped_iters += skipped_iter
         iteration += 1
...

@@ -370,8 +382,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, optimizer.loss_scale,
-                                          report_memory_flag, writer, args,
-                                          timers)
+                                          report_memory_flag)
 
         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
...
...
@@ -389,23 +400,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
args
.
do_valid
:
args
.
do_valid
:
prefix
=
'iteration {}'
.
format
(
iteration
)
prefix
=
'iteration {}'
.
format
(
iteration
)
evaluate_and_print_results
(
prefix
,
forward_step_func
,
evaluate_and_print_results
(
prefix
,
forward_step_func
,
val_data_iterator
,
model
,
args
,
val_data_iterator
,
model
,
writer
,
iteration
,
timers
,
False
)
iteration
,
False
)
if
args
.
exit_interval
and
iteration
%
args
.
exit_interval
==
0
:
if
args
.
exit_interval
and
iteration
%
args
.
exit_interval
==
0
:
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
time_str
=
datetime
.
now
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
time_str
=
datetime
.
now
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
print
(
'rank: {} | time: {} | exiting the program at
iteration {}'
.
print
_rank_0
(
'rank: {} | time: {} | exiting the program at
'
format
(
rank
,
time_str
,
iteration
)
,
flush
=
True
)
'iteration {}'
.
format
(
rank
,
time_str
,
iteration
))
exit
()
sys
.
exit
()
return
iteration
,
skipped_iters
return
iteration
,
skipped_iters
def
evaluate
(
forward_step_func
,
data_iterator
,
model
,
def
evaluate
(
forward_step_func
,
data_iterator
,
model
,
verbose
=
False
):
args
,
timers
,
verbose
=
False
):
"""Evaluation."""
"""Evaluation."""
args
=
get_args
()
# Turn on evaluation mode which disables dropout.
# Turn on evaluation mode which disables dropout.
model
.
eval
()
model
.
eval
()
...
@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model,
...
@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model,
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
args
.
eval_iters
))
args
.
eval_iters
))
# Forward evaluation.
# Forward evaluation.
_
,
loss_dict
=
forward_step_func
(
data_iterator
,
model
,
_
,
loss_dict
=
forward_step_func
(
data_iterator
,
model
)
args
,
timers
)
# Reduce across processes.
# Reduce across processes.
for
key
in
loss_dict
:
for
key
in
loss_dict
:
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
0.
)
+
\
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
0.
)
+
\
...
@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model,
...
@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model,
def
evaluate_and_print_results
(
prefix
,
forward_step_func
,
def
evaluate_and_print_results
(
prefix
,
forward_step_func
,
data_iterator
,
model
,
data_iterator
,
model
,
args
,
writer
,
iteration
,
iteration
,
verbose
=
False
):
timers
,
verbose
=
False
):
"""Helper function to evaluate and dump results on screen."""
"""Helper function to evaluate and dump results on screen."""
total_loss_dict
=
evaluate
(
forward_step_func
,
data_iterator
,
model
,
writer
=
get_tensorboard_writer
()
args
,
timers
,
verbose
)
total_loss_dict
=
evaluate
(
forward_step_func
,
data_iterator
,
model
,
verbose
)
string
=
' validation loss at {} | '
.
format
(
prefix
)
string
=
' validation loss at {} | '
.
format
(
prefix
)
for
key
in
total_loss_dict
:
for
key
in
total_loss_dict
:
string
+=
'{} value: {:.6E} | '
.
format
(
key
,
total_loss_dict
[
key
].
item
())
string
+=
'{} value: {:.6E} | '
.
format
(
key
,
total_loss_dict
[
key
].
item
())
...
@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
...
@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_0
(
'-'
*
length
)
print_rank_0
(
'-'
*
length
)
def
get_train_val_test_data_iterators
(
train_data
,
val_data
,
test_data
,
args
):
def
get_train_val_test_data_iterators
(
train_data
,
val_data
,
test_data
):
"""Build train/validation/test iterators"""
"""Build train/validation/test iterators"""
args
=
get_args
()
# Shift the start iterations.
# Shift the start iterations.
if
train_data
is
not
None
:
if
train_data
is
not
None
:
...
...
pretrain_bert.py

@@ -18,24 +18,28 @@
 import torch
 import torch.nn.functional as F
 
+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
 from megatron.model import BertModel
 from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
-from megatron.training import run
+from megatron.training import pretrain
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.data_utils.samplers import DistributedBatchSampler
 
 
-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()
 
     print_rank_0('building BERT model ...')
 
     model = BertModel(
         num_layers=args.num_layers,
-        vocab_size=args.vocab_size,
+        vocab_size=args.padded_vocab_size,
         hidden_size=args.hidden_size,
         num_attention_heads=args.num_attention_heads,
         embedding_dropout_prob=args.hidden_dropout,
...
@@ -46,7 +50,7 @@ def model_provider(args):
         checkpoint_num_layers=args.checkpoint_num_layers,
         add_binary_head=True,
         layernorm_epsilon=args.layernorm_epsilon,
-        num_tokentypes=args.tokentype_size,
+        num_tokentypes=2,
         parallel_output=True,
         apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
         attention_softmax_in_fp32=args.attention_softmax_in_fp32)
...
@@ -54,19 +58,17 @@ def model_provider(args):
     return model
 
 
-def get_batch(data_iterator, timers):
+def get_batch(data_iterator):
 
     # Items and their type.
     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
     datatype = torch.int64
 
     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is not None:
         data = next(data_iterator)
     else:
         data = None
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)
 
     # Unpack.
...
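get_batch() keeps the broadcast structure but no longer times it: only ranks that own a data iterator draw a sample, everyone else passes None, and mpu.broadcast_data() distributes the named int64 tensors. The sketch below mimics that shape using the generic torch.distributed.broadcast_object_list as a stand-in for mpu.broadcast_data (an analogy only, run as a one-process group so it executes standalone):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)


    def get_batch_like(data_iterator):
        # Same shape as the refactored get_batch(): draw locally if an iterator
        # exists, otherwise receive whatever rank 0 broadcasts.
        if data_iterator is not None:
            data = next(data_iterator)
        else:
            data = None
        payload = [data]
        dist.broadcast_object_list(payload, src=0)
        return payload[0]


    batch = get_batch_like(iter([{'text': torch.arange(8)}]))
    print(batch['text'])
    dist.destroy_process_group()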
@@ -80,13 +82,14 @@ def get_batch(data_iterator, timers):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
 
 
-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    timers = get_timers()
 
     # Get the batch.
     timers('batch generator').start()
     tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator, timers)
+        = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
...
@@ -108,8 +111,9 @@ def forward_step(data_iterator, model, args, timers):
     return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
 
 
-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
+    args = get_args()
 
     (train_data, valid_data, test_data) = (None, None, None)
...
@@ -202,6 +206,6 @@ if __name__ == "__main__":
                        'tokenizer_type': 'BertWordPieceLowerCase'})
     exit()
     '''
-    run('Pretrain BERT model', get_train_val_test_data,
-        model_provider, forward_step,
-        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    pretrain(get_train_val_test_data,
+             model_provider, forward_step,
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
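Both pretraining scripts now expose providers with identical shapes, and forward_step in particular must follow the contract from the pretrain() docstring: accept (data_iterator, model), time its batch generation under the 'batch generator' timer, and return a scalar loss plus a dictionary of values to monitor. A compact, hypothetical example of that contract with a toy model (not BERT or GPT-2 code):

    import torch
    import torch.nn.functional as F

    from megatron import get_timers


    def toy_forward_step(data_iterator, model):
        """Hypothetical forward_step following the refactored contract."""
        timers = get_timers()

        # Required by pretrain(): batch generation is timed as 'batch generator'.
        timers('batch generator').start()
        tokens, labels = next(data_iterator)
        timers('batch generator').stop()

        # Forward pass and a scalar loss.
        logits = model(tokens)
        loss = F.cross_entropy(logits, labels)

        # Second return value: quantities to monitor (Megatron reduces these
        # across data-parallel ranks with megatron.utils.reduce_losses).
        return loss, {'lm loss': loss}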
pretrain_gpt2.py

@@ -17,6 +17,10 @@
 import torch
 
+from megatron import get_args
+from megatron import get_timers
 from configure_data import configure_data
 from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
...
@@ -25,15 +29,16 @@ from megatron.utils import get_ltor_masks_and_position_ids
 from megatron import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
-from megatron.training import run
+from megatron.training import pretrain
 
 
-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()
 
     print_rank_0('building GPT2 model ...')
 
     model = GPT2Model(num_layers=args.num_layers,
-                      vocab_size=args.vocab_size,
+                      vocab_size=args.padded_vocab_size,
                       hidden_size=args.hidden_size,
                       num_attention_heads=args.num_attention_heads,
                       embedding_dropout_prob=args.hidden_dropout,
...
@@ -50,20 +55,19 @@ def model_provider(args):
     return model
 
 
-def get_batch(data_iterator, args, timers):
+def get_batch(data_iterator):
     """Generate a batch"""
+    args = get_args()
 
     # Items and their type.
     keys = ['text']
     datatype = torch.int64
 
     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is not None:
         data = next(data_iterator)
     else:
         data = None
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)
 
     # Unpack.
...
@@ -85,8 +89,10 @@ def get_batch(data_iterator, args, timers):
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
+    timers = get_timers()
 
     # Get the batch.
     timers('batch generator').start()
...
@@ -107,8 +113,9 @@ def forward_step(data_iterator, model, args, timers):
     return loss, {'lm loss': reduced_loss[0]}
 
 
-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
+    args = get_args()
 
     (train_data, val_data, test_data) = (None, None, None)
...
@@ -162,5 +169,5 @@ def get_train_val_test_data(args):
 if __name__ == "__main__":
 
-    run('Pretrain GPT-2 model', get_train_val_test_data,
-        model_provider, forward_step)
+    pretrain(get_train_val_test_data,
+             model_provider, forward_step)