Commit 3366a5b0 authored by Mohammad

refactored pretrain-bert

parent 27e14f82
@@ -20,15 +20,15 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
-from megatron.model import BertModel
 from megatron import print_rank_0
-from megatron.utils import reduce_losses
-from megatron.utils import vocab_size_with_padding
-from megatron.training import pretrain
+from megatron.data.bert_dataset import build_train_valid_test_datasets
+from megatron.data_utils.samplers import DistributedBatchSampler
+from megatron.model import BertModel
+from megatron.training import pretrain
+from megatron.utils import reduce_losses


 def model_provider():
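One notable removal in this hunk is `vocab_size_with_padding`: after the refactor, the padded vocabulary size is no longer computed and broadcast inside this script. For context, here is a minimal sketch of what such a helper typically does in Megatron-style code, rounding the vocabulary up so the embedding table divides evenly across model-parallel ranks. The function name and the `multiple=128` default are illustrative assumptions, not this repo's exact API:

```python
def pad_vocab_size(orig_vocab_size, multiple=128, model_parallel_size=1):
    """Round the vocabulary size up so the embedding table divides
    evenly across model-parallel ranks (and keeps GEMMs well-shaped).
    Illustrative sketch; not the repo's actual helper."""
    divisor = multiple * model_parallel_size
    padded = orig_vocab_size
    while padded % divisor != 0:
        padded += 1
    return padded

# BERT's 30522-token vocab padded for 2-way model parallelism.
assert pad_vocab_size(30522, model_parallel_size=2) == 30720
```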
@@ -114,7 +114,7 @@ def forward_step(data_iterator, model):
 def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
     args = get_args()
     (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
@@ -176,36 +176,23 @@ def get_train_val_test_data():
         do_valid = valid_data is not None and args.eval_iters > 0
         do_test = test_data is not None and args.eval_iters > 0

-        # Need to broadcast num_tokens and num_type_tokens.
-        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
-        token_counts = torch.cuda.LongTensor([num_tokens,
-                                              2, # hard coded num_type_tokens
-                                              int(do_train),
-                                              int(do_valid),
-                                              int(do_test)])
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
     else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
+        flags = torch.cuda.LongTensor([0, 0, 0])

-    # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
+    torch.distributed.broadcast(flags,
                                 mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
-    args.vocab_size = token_counts[0].item()
-    args.tokentype_size = token_counts[1].item()
-    args.do_train = token_counts[2].item()
-    args.do_valid = token_counts[3].item()
-    args.do_test = token_counts[4].item()
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()

     return train_data, valid_data, test_data
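The net effect of this hunk: the data-loading rank no longer broadcasts vocab and token-type sizes, only the three train/valid/test availability flags. Below is a minimal, self-contained sketch of the same pack-and-broadcast pattern. It uses the CPU `gloo` backend and a single process so it runs standalone, where the real script uses `torch.cuda.LongTensor` and a model-parallel group; `broadcast_data_flags` is an illustrative name, not Megatron's API:

```python
import os
import torch
import torch.distributed as dist

def broadcast_data_flags(do_train, do_valid, do_test, src_rank=0, group=None):
    """Pack boolean flags into a LongTensor on the source rank and
    broadcast them so every rank agrees on which splits exist."""
    if dist.get_rank() == src_rank:
        flags = torch.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.LongTensor([0, 0, 0])
    dist.broadcast(flags, src_rank, group=group)
    return bool(flags[0].item()), bool(flags[1].item()), bool(flags[2].item())

if __name__ == "__main__":
    # Single-process setup so the sketch runs standalone; real runs
    # would use multiple ranks and the NCCL backend on GPUs.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(broadcast_data_flags(True, True, False))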
 if __name__ == "__main__":
-    '''
-    from megatron.initialize import initialize_megatron
-    initialize_megatron(args_defaults={
-        'tokenizer_type': 'BertWordPieceLowerCase'})
-    exit()
-    '''
-    pretrain(get_train_val_test_data,
-             model_provider, forward_step,
+    pretrain(get_train_val_test_data, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
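The final hunk also drops the commented-out `initialize_megatron` experiment, leaving `pretrain` as the single entry point that receives `args_defaults`. As a rough sketch of what passing `args_defaults` likely means (a script-level default applied unless the user overrides it on the command line), here is a toy argparse version; the names, the `--tokenizer-type` flag's library default, and the override semantics are all assumptions for illustration, not Megatron's actual parsing code:

```python
import argparse

def parse_args_with_defaults(args_defaults=None, argv=None):
    # Toy stand-in for a Megatron-style argument parser.
    parser = argparse.ArgumentParser(description='toy arg parsing sketch')
    parser.add_argument('--tokenizer-type', default='GPT2BPETokenizer')
    args = parser.parse_args([] if argv is None else argv)
    # Apply script-level defaults only where the user kept the library default.
    for key, value in (args_defaults or {}).items():
        if getattr(args, key) == parser.get_default(key):
            setattr(args, key, value)
    return args

# The BERT script pins the tokenizer unless overridden on the command line.
print(parse_args_with_defaults({'tokenizer_type': 'BertWordPieceLowerCase'}))
```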