Commit 1788c910 authored by Mohammad

Both BERT and GPT-2 tested and working

parent 5f8623db
@@ -312,9 +312,16 @@ def _add_data_args(parser):
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
parser.add_argument('--data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
help='Implementation of indexed datasets.')
group.add_argument('--data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
help='Implementation of indexed datasets.')
group.add_argument('--reset-position-ids', action='store_true',
help='Reset position ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention mask after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
return parser
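The hunk above also moves --data-impl from parser.add_argument to group.add_argument, so the option is registered on the data argument group rather than the bare parser. A minimal sketch of that argparse pattern; the group title and the trimmed argument list are assumptions for illustration, not the contents of the real _add_data_args:

# Illustrative sketch only; group title and argument subset are assumptions.
import argparse

def _add_data_args(parser):
    # Arguments added through the group are listed together under its heading
    # in --help output; adding them to the bare parser would not be.
    group = parser.add_argument_group(title='data and dataloader')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    return parser

parser = _add_data_args(argparse.ArgumentParser())
args = parser.parse_args(['--data-impl', 'mmap', '--reset-position-ids'])
print(args.data_impl, args.reset_position_ids)   # -> mmap True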
@@ -340,13 +347,6 @@ def _add_gpt2_args(parser):
group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
help='The filename containing all the shards '
'sizes for numpy data loader')
group.add_argument('--reset-position-ids', action='store_true',
help='Reset position ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention mask after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
return parser
@@ -21,8 +21,10 @@ import torch
from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer
@@ -87,7 +89,30 @@ def check_adlr_autoresume_termination(iteration, model,
sys.exit(0)
###################################################
def make_data_loader(dataset):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
# Data parallel arguments.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ltor_masks_and_position_ids(data,
@@ -145,4 +170,3 @@ def get_ltor_masks_and_position_ids(data,
prev_index = i + 1
return attention_mask, loss_mask, position_ids
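The trailing context above (prev_index = i + 1) is the per-document reset that the new --reset-position-ids and --reset-attention-mask flags control. A standalone toy sketch of the position-id reset, with an assumed end-of-document token id of 0; this is an illustration of the idea, not the Megatron implementation:

# Toy illustration of restarting position ids at 0 after each EOD token.
# The EOD id, tensor shapes, and function name are assumptions.
import torch

def toy_position_ids(tokens, eod_token, reset_position_ids=True):
    """Restart position ids at zero after every end-of-document token."""
    b, s = tokens.shape
    position_ids = torch.arange(s, dtype=torch.long).unsqueeze(0).repeat(b, 1)
    if reset_position_ids:
        for row in range(b):
            prev_index = 0
            for i in torch.nonzero(tokens[row] == eod_token).flatten().tolist():
                # Positions after this document boundary count from zero again.
                position_ids[row, i + 1:] -= (i + 1 - prev_index)
                prev_index = i + 1
    return position_ids

tokens = torch.tensor([[5, 7, 0, 3, 4, 0, 9, 9]])   # 0 is the assumed EOD id
print(toy_position_ids(tokens, eod_token=0))
# tensor([[0, 1, 2, 0, 1, 2, 0, 1]])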
@@ -23,14 +23,12 @@ from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
def model_provider():
"""Build the model."""
args = get_args()
@@ -151,26 +149,9 @@ def get_train_val_test_data():
skip_warmup=(not args.mmap_warmup))
print_rank_0("> finished creating BERT datasets ...")
def make_data_loader_(dataset):
if not dataset:
return None
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(
sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=data_parallel_rank,
world_size=data_parallel_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True)
train_data = make_data_loader_(train_ds)
valid_data = make_data_loader_(valid_ds)
test_data = make_data_loader_(test_ds)
train_data = make_data_loader(train_ds)
valid_data = make_data_loader(valid_ds)
test_data = make_data_loader(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
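The BERT pretraining hunk above drops its local make_data_loader_ closure in favor of the shared megatron.utils.make_data_loader introduced earlier in this commit. A rough sketch of how the DistributedBatchSampler is assumed to split the batch_size * world_size global batch across data-parallel ranks; this models its assumed behavior, not its actual code:

# Each rank keeps one contiguous slice of the global batch of sample indices.
# This mimics the assumed DistributedBatchSampler semantics for illustration.
def rank_slice(global_batch, rank, world_size):
    """Return this rank's contiguous slice of a global batch of indices."""
    per_rank = len(global_batch) // world_size
    start = rank * per_rank
    return global_batch[start:start + per_rank]

global_batch = list(range(16))           # batch_size=4, world_size=4 -> 16
for rank in range(4):
    print(rank, rank_slice(global_batch, rank, world_size=4))
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9, 10, 11]
# 3 [12, 13, 14, 15]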
@@ -25,10 +25,10 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data_utils.samplers import DistributedBatchSampler
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
@@ -121,32 +121,19 @@ def make_gpt2_dataloaders():
seq_length = args.seq_length
initial_seed = args.seed
# Data parallel arguments.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
def make_data_loader_(data_path):
# Build the dataset.
dataset = GPT2Dataset(data_path, input_data_sizes_file,
seq_length, initial_seed)
# Use a simple sampler with distributed batch sampler.
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(sampler=sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
train = make_data_loader_(os.path.join(args.data_path, 'train'))
valid = make_data_loader_(os.path.join(args.data_path, 'valid'))
test = make_data_loader_(os.path.join(args.data_path, 'test'))
# Build the datasets.
def build_dataset_(name):
return GPT2Dataset(os.path.join(args.data_path, name),
args.input_data_sizes_file,
args.seq_length, args.seed)
train_ds = build_dataset_('train')
valid_ds = build_dataset_('valid')
test_ds = build_dataset_('test')
# Dataloaders
train = make_data_loader(train_ds)
valid = make_data_loader(valid_ds)
test = make_data_loader(test_ds)
args.do_train = False
args.do_valid = False
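In the GPT-2 script, the refactor above separates dataset construction (build_dataset_) from dataloader construction (make_data_loader). A hedged sketch of how a training step might consume these loaders together with the new masking flags; the raw-tensor batch format, the tokenizer.eod attribute, and the argument order of get_ltor_masks_and_position_ids beyond the data tensor are assumptions, not verified against the repository:

# Illustrative sketch only; batch layout and the exact signature of
# get_ltor_masks_and_position_ids are assumptions.
from megatron import get_args, get_tokenizer
from megatron.utils import get_ltor_masks_and_position_ids

def get_batch(data_iterator):
    args = get_args()
    tokenizer = get_tokenizer()
    tokens = next(data_iterator).long().cuda()   # assumed token tensor batch
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,                  # assumed end-of-document token id
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    return tokens, attention_mask, loss_mask, position_ids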