Commit 27e14f82 authored by Mohammad

refactored training

parent 3f58649b
...@@ -234,6 +234,8 @@ def _add_mixed_precision_args(parser): ...@@ -234,6 +234,8 @@ def _add_mixed_precision_args(parser):
'attention-softmax-in-fp32 to true') 'attention-softmax-in-fp32 to true')
group.add_argument('--attention-softmax-in-fp32', action='store_true', group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32.') help='Run attention masking and softmax in fp32.')
group.add_argument('--fp32-allreduce', action='store_true',
help='All-reduce in fp32')
group.add_argument('--hysteresis', type=int, default=2, group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling') help='hysteresis for dynamic loss scaling')
group.add_argument('--loss-scale', type=float, default=None, group.add_argument('--loss-scale', type=float, default=None,
......
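
The new --fp32-allreduce flag is only registered here. As a rough illustration of how such a flag is commonly honored, a data-parallel gradient reduction can up-cast half-precision gradients before the cross-rank sum; the sketch below is mine (function and parameter names are assumptions, not Megatron's implementation) and assumes an already-initialized process group.

# Illustrative only: honoring an fp32-allreduce style flag when averaging
# gradients across data-parallel ranks. Not Megatron's actual code.
import torch
import torch.distributed as dist

def allreduce_gradients(model, fp32_allreduce=False, group=None):
    """Average gradients in-place across the given process group."""
    world_size = dist.get_world_size(group=group)
    for param in model.parameters():
        if param.grad is None:
            continue
        grad = param.grad.data
        # Up-cast fp16 gradients so the cross-rank sum happens in fp32.
        buf = grad.float() if fp32_allreduce and grad.dtype == torch.half else grad
        dist.all_reduce(buf, group=group)
        buf.div_(world_size)
        if buf is not grad:
            grad.copy_(buf)  # cast the averaged result back into the fp16 grad
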
...@@ -13,62 +13,57 @@ ...@@ -13,62 +13,57 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Pretrain utilities""" """Pretrain utilities."""
from datetime import datetime from datetime import datetime
import math import math
import sys
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron.global_vars import get_args from megatron import get_args
from megatron.global_vars import get_timers from megatron import get_timers
from megatron.global_vars import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron.global_vars import get_adlr_autoresume
from megatron.initialize import initialize_megatron
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.checkpointing import load_checkpoint
from megatron import print_rank_0
from megatron.utils import report_memory from megatron.utils import report_memory
from megatron.checkpointing import save_checkpoint
def run(top_level_message, train_val_test_data_provider, def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
model_provider, forward_step_func, extra_args_provider=None, extra_args_provider=None, args_defaults={}):
args_defaults={}):
"""Main training program. """Main training program.
This function will run the following in the order provided: This function will run the following in the order provided:
1) get input arguments. 1) initialize Megatron.
2) initialize distributed and seeds. 2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets. 3) call train_val_test_data_provider to get train/val/test datasets.
4) setup model, optimizer and lr schedule using the model_provider. 4) train the model using the forward_step_func.
5) train the model using the forward_step_func.
Arguments: Arguments:
top_level_message: a meesage to print at the top of the run. train_val_test_data_provider: a function that builds datasets
train_val_test_data_provider: a function that takes `args` as input and returns `train, val, test` dataloaders.
and returns `train, val, test` dataloaders. Note that args are model_provider: a function that returns a vanilla version of the
passed and can be modified in case we need to use some parameters model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
later. For example, we can set vocab size using forward_step_func: a function that takes a `data iterator` and `model`,
args.vocab_size = ... and returns a `loss` scalar and a dictionary whose key:values are
and later use this value in `model_provider`. the info we would like to monitor during training, for example
model_provider: a function that takes `args` and returns a vanilla `lm-loss: value`. We also require that this function add
version of the model. By vanilla we mean a simple model on cpu `batch generator` to the timers class.
with no fp16 or ddp. extra_args_provider: a function that takes a parser and adds arguments
forward_step_func: a function that takes a `data iterator`, `model`, to it. It is used for programs to add their own arguments.
`args`, and `timers` and returns a `loss` scalar with a dictionary args_defaults: a dictionary from argument-name to argument-value. It
with key:values being the info we would like to monitor during to set already parse arguments.
training, for example `lm-loss: value`. We also require that this is used to set defaults for already-parsed arguments.
function add `batch generator` to the timers class.
""" """
# Initialize and get arguments, timers, and Tensorboard writer. # Initialize and get arguments, timers, and Tensorboard writer.
...@@ -76,36 +71,44 @@ def run(top_level_message, train_val_test_data_provider, ...@@ -76,36 +71,44 @@ def run(top_level_message, train_val_test_data_provider,
args_defaults=args_defaults) args_defaults=args_defaults)
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
writer = get_tensorboard_writer()
# Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args)
# Model, optimizer, and learning rate. # Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, timers('model and optimizer').start()
args) model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
# Data stuff.
timers('train/valid/test dataset').start()
train_data, val_data, test_data = train_val_test_data_provider()
timers('train/valid/test dataset').stop()
# Train, validation, and test data. # Train, validation, and test data.
timers('train/valid/test dataloader').start()
train_data_iterator, val_data_iterator, \ train_data_iterator, val_data_iterator, \
test_data_iterator = get_train_val_test_data_iterators(train_data, test_data_iterator = get_train_val_test_data_iterators(train_data,
val_data, val_data,
test_data, test_data)
args) timers('train/valid/test dataloader').stop()
# Print setup timing.
print_rank_0('done with setups ...')
timers.log(['model and optimizer', 'train/valid/test dataset',
'train/valid/test dataloader'])
print_rank_0('training ...')
iteration = 0 iteration = 0
if args.train_iters > 0: if args.train_iters > 0:
if args.do_train: if args.do_train:
iteration, _ = train(forward_step_func, model, iteration, _ = train(forward_step_func,
optimizer, lr_scheduler, model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, train_data_iterator, val_data_iterator)
timers, args, writer)
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model, val_data_iterator, model,
args, writer, iteration, iteration, False)
timers, False)
if args.save and iteration != 0: if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, lr_scheduler)
...@@ -115,14 +118,15 @@ def run(top_level_message, train_val_test_data_provider, ...@@ -115,14 +118,15 @@ def run(top_level_message, train_val_test_data_provider,
prefix = 'the end of training for test data' prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model, test_data_iterator, model,
args, None, 0, timers, True) 0, True)
def get_model(model_provider_func, args): def get_model(model_provider_func):
"""Build the model.""" """Build the model."""
args = get_args()
# Build model on cpu. # Build model on cpu.
model = model_provider_func(args) model = model_provider_func()
# Print number of parameters. # Print number of parameters.
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
...@@ -140,26 +144,24 @@ def get_model(model_provider_func, args): ...@@ -140,26 +144,24 @@ def get_model(model_provider_func, args):
# Wrap model for distributed training. # Wrap model for distributed training.
if args.DDP_impl == 'torch': if args.DDP_impl == 'torch':
i = torch.cuda.current_device() i = torch.cuda.current_device()
args.DDP_type = torchDDP model = torchDDP(model, device_ids=[i], output_device=i,
model = args.DDP_type(model, device_ids=[i], output_device=i, process_group=mpu.get_data_parallel_group())
process_group=mpu.get_data_parallel_group())
return model return model
if args.DDP_impl == 'local': if args.DDP_impl == 'local':
args.DDP_type = LocalDDP model = LocalDDP(model)
model = args.DDP_type(model)
return model return model
print_rank_0('Unknown DDP implementation specified: {}. ' print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl)) 'Exiting.'.format(args.DDP_impl))
exit() sys.exit()
return model
def get_optimizer(model, args): def get_optimizer(model):
"""Set up the optimizer.""" """Set up the optimizer."""
args = get_args()
# Build parameter groups (weight decay and non-decay). # Build parameter groups (weight decay and non-decay).
while isinstance(model, (args.DDP_type, FP16_Module)): while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
model = model.module model = model.module
param_groups = get_params_for_weight_decay_optimization(model) param_groups = get_params_for_weight_decay_optimization(model)
...@@ -170,8 +172,7 @@ def get_optimizer(model, args): ...@@ -170,8 +172,7 @@ def get_optimizer(model, args):
param.model_parallel = False param.model_parallel = False
# Use Adam. # Use Adam.
optimizer = Adam(param_groups, optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer. # Wrap into fp16 optimizer.
if args.fp16: if args.fp16:
...@@ -186,8 +187,9 @@ def get_optimizer(model, args): ...@@ -186,8 +187,9 @@ def get_optimizer(model, args):
return optimizer return optimizer
def get_learning_rate_scheduler(optimizer, args): def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler.""" """Build the learning rate scheduler."""
args = get_args()
# Add linear learning rate scheduler. # Add linear learning rate scheduler.
if args.lr_decay_iters is not None: if args.lr_decay_iters is not None:
...@@ -211,12 +213,13 @@ def get_learning_rate_scheduler(optimizer, args): ...@@ -211,12 +213,13 @@ def get_learning_rate_scheduler(optimizer, args):
return lr_scheduler return lr_scheduler
def setup_model_and_optimizer(model_provider_func, args): def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer.""" """Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func, args) model = get_model(model_provider_func)
optimizer = get_optimizer(model, args) optimizer = get_optimizer(model)
lr_scheduler = get_learning_rate_scheduler(optimizer, args) lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.load is not None: if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler) args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
...@@ -226,8 +229,10 @@ def setup_model_and_optimizer(model_provider_func, args): ...@@ -226,8 +229,10 @@ def setup_model_and_optimizer(model_provider_func, args):
return model, optimizer, lr_scheduler return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss, args, timers): def backward_step(optimizer, model, loss):
"""Backward step.""" """Backward step."""
args = get_args()
timers = get_timers()
# Backward pass. # Backward pass.
optimizer.zero_grad() optimizer.zero_grad()
...@@ -255,18 +260,20 @@ def backward_step(optimizer, model, loss, args, timers): ...@@ -255,18 +260,20 @@ def backward_step(optimizer, model, loss, args, timers):
optimizer.clip_master_grads(args.clip_grad) optimizer.clip_master_grads(args.clip_grad)
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler, def train_step(forward_step_func, data_iterator,
args, timers): model, optimizer, lr_scheduler):
"""Single training step.""" """Single training step."""
args = get_args()
timers = get_timers()
# Forward model for one step. # Forward model for one step.
timers('forward').start() timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model, args, timers) loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop() timers('forward').stop()
# Calculate gradients, reduce across processes, and clip. # Calculate gradients, reduce across processes, and clip.
timers('backward').start() timers('backward').start()
backward_step(optimizer, model, loss, args, timers) backward_step(optimizer, model, loss)
timers('backward').stop() timers('backward').stop()
# Update parameters. # Update parameters.
...@@ -285,7 +292,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler, ...@@ -285,7 +292,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration, def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, writer, args, timers): loss_scale, report_memory_flag):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Update losses. # Update losses.
for key in loss_dict: for key in loss_dict:
...@@ -341,8 +352,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -341,8 +352,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
def train(forward_step_func, model, optimizer, lr_scheduler, def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, writer): train_data_iterator, val_data_iterator):
"""Train the model function.""" """Train the model function."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout. # Turn on training mode which enables dropout.
model.train() model.train()
...@@ -361,8 +374,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -361,8 +374,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, train_data_iterator,
model, model,
optimizer, optimizer,
lr_scheduler, lr_scheduler)
args, timers)
skipped_iters += skipped_iter skipped_iters += skipped_iter
iteration += 1 iteration += 1
...@@ -370,8 +382,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -370,8 +382,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict, report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['lr'],
iteration, optimizer.loss_scale, iteration, optimizer.loss_scale,
report_memory_flag, writer, args, report_memory_flag)
timers)
# Autoresume # Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and \ if (iteration % args.adlr_autoresume_interval == 0) and \
...@@ -389,23 +400,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -389,23 +400,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
args.do_valid: args.do_valid:
prefix = 'iteration {}'.format(iteration) prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model, args, val_data_iterator, model,
writer, iteration, timers, False) iteration, False)
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier() torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'. print_rank_0('rank: {} | time: {} | exiting the program at '
format(rank, time_str, iteration), flush=True) 'iteration {}'.format(rank, time_str, iteration))
exit() sys.exit()
return iteration, skipped_iters return iteration, skipped_iters
def evaluate(forward_step_func, data_iterator, model, def evaluate(forward_step_func, data_iterator, model, verbose=False):
args, timers, verbose=False):
"""Evaluation.""" """Evaluation."""
args = get_args()
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
model.eval() model.eval()
...@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model, ...@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model,
print_rank_0('Evaluating iter {}/{}'.format(iteration, print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters)) args.eval_iters))
# Forward evaluation. # Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model, _, loss_dict = forward_step_func(data_iterator, model)
args, timers)
# Reduce across processes. # Reduce across processes.
for key in loss_dict: for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \ total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
...@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model, ...@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model,
def evaluate_and_print_results(prefix, forward_step_func, def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model, data_iterator, model,
args, writer, iteration, iteration, verbose=False):
timers, verbose=False):
"""Helper function to evaluate and dump results on screen.""" """Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(forward_step_func, data_iterator, model, writer = get_tensorboard_writer()
args, timers, verbose)
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
string = ' validation loss at {} | '.format(prefix) string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict: for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
...@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func, ...@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_0('-' * length) print_rank_0('-' * length)
def get_train_val_test_data_iterators(train_data, val_data, test_data, args): def get_train_val_test_data_iterators(train_data, val_data, test_data):
"""Build train/validation/test iterators""" """Build train/validation/test iterators"""
args = get_args()
# Shift the start iterations. # Shift the start iterations.
if train_data is not None: if train_data is not None:
......
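
After this refactor, the callbacks handed to pretrain() no longer receive args/timers as parameters; they fetch them from megatron's module-level accessors. A minimal sketch of the new contract follows; ToyModel and build_datasets are hypothetical placeholders, only the signatures and the `batch generator` timer requirement come from the code above.

# Sketch of the refactored pretrain() contract (placeholders are hypothetical).
from megatron import get_args, get_timers
from megatron.training import pretrain

def train_val_test_data_provider():
    args = get_args()             # arguments now come from the global store
    return build_datasets(args)   # -> train, valid, test dataloaders (placeholder)

def model_provider():
    args = get_args()
    return ToyModel(hidden_size=args.hidden_size)  # plain CPU model, no fp16/DDP

def forward_step(data_iterator, model):
    timers = get_timers()
    timers('batch generator').start()  # pretrain() expects this timer to be added
    batch = next(data_iterator)
    timers('batch generator').stop()
    loss = model(batch)
    return loss, {'lm loss': loss}

if __name__ == '__main__':
    pretrain(train_val_test_data_provider, model_provider, forward_step)
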
...@@ -18,24 +18,28 @@ ...@@ -18,24 +18,28 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.model import BertModel from megatron.model import BertModel
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding from megatron.utils import vocab_size_with_padding
from megatron.training import run from megatron.training import pretrain
from megatron.data.bert_dataset import build_train_valid_test_datasets from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args): def model_provider():
"""Build the model.""" """Build the model."""
args = get_args()
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
model = BertModel( model = BertModel(
num_layers=args.num_layers, num_layers=args.num_layers,
vocab_size=args.vocab_size, vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size, hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads, num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout, embedding_dropout_prob=args.hidden_dropout,
...@@ -46,7 +50,7 @@ def model_provider(args): ...@@ -46,7 +50,7 @@ def model_provider(args):
checkpoint_num_layers=args.checkpoint_num_layers, checkpoint_num_layers=args.checkpoint_num_layers,
add_binary_head=True, add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon, layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size, num_tokentypes=2,
parallel_output=True, parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32) attention_softmax_in_fp32=args.attention_softmax_in_fp32)
...@@ -54,19 +58,17 @@ def model_provider(args): ...@@ -54,19 +58,17 @@ def model_provider(args):
return model return model
def get_batch(data_iterator, timers): def get_batch(data_iterator):
# Items and their type. # Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'] keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64 datatype = torch.int64
# Broadcast data. # Broadcast data.
timers('data loader').start()
if data_iterator is not None: if data_iterator is not None:
data = next(data_iterator) data = next(data_iterator)
else: else:
data = None data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype) data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack. # Unpack.
...@@ -80,13 +82,14 @@ def get_batch(data_iterator, timers): ...@@ -80,13 +82,14 @@ def get_batch(data_iterator, timers):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model, args, timers): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator, timers) = get_batch(data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
...@@ -108,9 +111,10 @@ def forward_step(data_iterator, model, args, timers): ...@@ -108,9 +111,10 @@ def forward_step(data_iterator, model, args, timers):
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]} return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
def get_train_val_test_data(args): def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS.""" """Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, valid_data, test_data) = (None, None, None) (train_data, valid_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
...@@ -202,6 +206,6 @@ if __name__ == "__main__": ...@@ -202,6 +206,6 @@ if __name__ == "__main__":
'tokenizer_type': 'BertWordPieceLowerCase'}) 'tokenizer_type': 'BertWordPieceLowerCase'})
exit() exit()
''' '''
run('Pretrain BERT model', get_train_val_test_data, pretrain(get_train_val_test_data,
model_provider, forward_step, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
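
A note on the vocab_size -> args.padded_vocab_size change above: the vocabulary is padded so the embedding table splits evenly across model-parallel ranks. The sketch below shows the general padding rule only; the exact divisor and the arguments of megatron.utils.vocab_size_with_padding may differ.

# Hedged sketch of vocab padding for model parallelism; the real helper
# may use different arguments or defaults.
def pad_vocab_size(orig_vocab_size, divisible_by=128, model_parallel_size=1):
    multiple = divisible_by * model_parallel_size
    padded = orig_vocab_size
    while padded % multiple != 0:
        padded += 1
    return padded

# Example: BERT's 30522-token vocab with 2-way model parallelism.
print(pad_vocab_size(30522, divisible_by=128, model_parallel_size=2))  # 30720
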
...@@ -17,6 +17,10 @@ ...@@ -17,6 +17,10 @@
import torch import torch
from megatron import get_args
from megatron import get_timers
from configure_data import configure_data from configure_data import configure_data
from gpt2_data_loader import make_gpt2_dataloaders from gpt2_data_loader import make_gpt2_dataloaders
from megatron import mpu from megatron import mpu
...@@ -25,15 +29,16 @@ from megatron.utils import get_ltor_masks_and_position_ids ...@@ -25,15 +29,16 @@ from megatron.utils import get_ltor_masks_and_position_ids
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding from megatron.utils import vocab_size_with_padding
from megatron.training import run from megatron.training import pretrain
def model_provider(args): def model_provider():
"""Build the model.""" """Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...') print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers, model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size, vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size, hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads, num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout, embedding_dropout_prob=args.hidden_dropout,
...@@ -50,20 +55,19 @@ def model_provider(args): ...@@ -50,20 +55,19 @@ def model_provider(args):
return model return model
def get_batch(data_iterator, args, timers): def get_batch(data_iterator):
"""Generate a batch""" """Generate a batch"""
args = get_args()
# Items and their type. # Items and their type.
keys = ['text'] keys = ['text']
datatype = torch.int64 datatype = torch.int64
# Broadcast data. # Broadcast data.
timers('data loader').start()
if data_iterator is not None: if data_iterator is not None:
data = next(data_iterator) data = next(data_iterator)
else: else:
data = None data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype) data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack. # Unpack.
...@@ -85,8 +89,10 @@ def get_batch(data_iterator, args, timers): ...@@ -85,8 +89,10 @@ def get_batch(data_iterator, args, timers):
return tokens, labels, loss_mask, attention_mask, position_ids return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model, args, timers): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
args = get_args()
timers = get_timers()
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator').start()
...@@ -107,9 +113,10 @@ def forward_step(data_iterator, model, args, timers): ...@@ -107,9 +113,10 @@ def forward_step(data_iterator, model, args, timers):
return loss, {'lm loss': reduced_loss[0]} return loss, {'lm loss': reduced_loss[0]}
def get_train_val_test_data(args): def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS.""" """Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, val_data, test_data) = (None, None, None) (train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group. # Data loader only on rank 0 of each model parallel group.
...@@ -162,5 +169,5 @@ def get_train_val_test_data(args): ...@@ -162,5 +169,5 @@ def get_train_val_test_data(args):
if __name__ == "__main__": if __name__ == "__main__":
run('Pretrain GPT-2 model', get_train_val_test_data, pretrain(get_train_val_test_data,
model_provider, forward_step) model_provider, forward_step)
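
Most of the churn in this commit comes from routing args and timers through module-level accessors (get_args, get_timers, get_tensorboard_writer) instead of function parameters. A minimal sketch of such a global store is below; it is illustrative only, and megatron/global_vars.py may differ in detail.

# Illustrative module-level "global vars" store of the kind this refactor uses.
_GLOBAL_ARGS = None
_GLOBAL_TIMERS = None

def set_global_variables(args, timers):
    """Install the parsed args and timers exactly once."""
    global _GLOBAL_ARGS, _GLOBAL_TIMERS
    assert _GLOBAL_ARGS is None, 'args already initialized'
    _GLOBAL_ARGS = args
    _GLOBAL_TIMERS = timers

def get_args():
    """Return the parsed arguments; valid only after initialization."""
    assert _GLOBAL_ARGS is not None, 'args not initialized'
    return _GLOBAL_ARGS

def get_timers():
    """Return the shared timers object; valid only after initialization."""
    assert _GLOBAL_TIMERS is not None, 'timers not initialized'
    return _GLOBAL_TIMERS
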