Commit 1788c910 authored by Mohammad

both bert and gpt2 tested and working

parent 5f8623db
@@ -312,9 +312,16 @@ def _add_data_args(parser):
                        choices=['BertWordPieceLowerCase',
                                 'GPT2BPETokenizer'],
                        help='What type of tokenizer to use.')
-    parser.add_argument('--data-impl', type=str, default='infer',
+    group.add_argument('--data-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'],
                        help='Implementation of indexed datasets.')
+    group.add_argument('--reset-position-ids', action='store_true',
+                       help='Reset position ids after end-of-document token.')
+    group.add_argument('--reset-attention-mask', action='store_true',
+                       help='Reset self attention mask after '
+                       'end-of-document token.')
+    group.add_argument('--eod-mask-loss', action='store_true',
+                       help='Mask loss for the end of document tokens.')

     return parser
@@ -340,13 +347,6 @@ def _add_gpt2_args(parser):
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
                        help='The filename containing all the shards '
                        'sizes for numpy data loader')
-    group.add_argument('--reset-position-ids', action='store_true',
-                       help='Reset position ids after end-of-document token.')
-    group.add_argument('--reset-attention-mask', action='store_true',
-                       help='Reset self attention mask after '
-                       'end-of-document token.')
-    group.add_argument('--eod-mask-loss', action='store_true',
-                       help='Mask loss for the end of document tokens.')

     return parser
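The three flags moved into the common data-argument group control how position ids and the loss mask treat end-of-document (EOD) tokens, so both the BERT and GPT-2 scripts can request the behavior. Below is a minimal, self-contained sketch of what --reset-position-ids and --eod-mask-loss mean; it is not the Megatron implementation (the real logic lives in get_ltor_masks_and_position_ids), and the function name, token values, and EOD id are made up for illustration. --reset-attention-mask analogously keeps attention from crossing document boundaries and is omitted here for brevity.

```python
import torch

def sketch_position_ids_and_loss_mask(tokens, eod_token,
                                      reset_position_ids, eod_mask_loss):
    """Toy version of the EOD handling selected by the new flags."""
    seq_length = tokens.numel()
    position_ids = torch.arange(seq_length, dtype=torch.long)
    loss_mask = torch.ones(seq_length, dtype=torch.float)

    if eod_mask_loss:
        # --eod-mask-loss: no loss is computed on end-of-document tokens.
        loss_mask[tokens == eod_token] = 0.0

    if reset_position_ids:
        # --reset-position-ids: position counting restarts after each EOD token.
        prev_index = 0
        for i in (tokens == eod_token).nonzero(as_tuple=False).view(-1).tolist():
            position_ids[i + 1:] -= i + 1 - prev_index
            prev_index = i + 1

    return position_ids, loss_mask

tokens = torch.tensor([5, 6, 7, 0, 8, 9])  # two documents, 0 playing the EOD id
pos, loss = sketch_position_ids_and_loss_mask(tokens, 0, True, True)
print(pos.tolist())   # [0, 1, 2, 3, 0, 1]
print(loss.tolist())  # [1.0, 1.0, 1.0, 0.0, 1.0, 1.0]
```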
@@ -21,8 +21,10 @@ import torch

 from megatron import get_args
 from megatron import get_adlr_autoresume
+from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
+from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer
@@ -87,7 +89,30 @@ def check_adlr_autoresume_termination(iteration, model,
             sys.exit(0)


-###################################################
+def make_data_loader(dataset):
+    """Build dataloader given an input dataset."""
+
+    if dataset is None:
+        return None
+    args = get_args()
+
+    # Data parallel arguments.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    # Use a simple sampler with distributed batch sampler.
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler=sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 def get_ltor_masks_and_position_ids(data,
@@ -145,4 +170,3 @@ def get_ltor_masks_and_position_ids(data,
                 prev_index = i + 1

     return attention_mask, loss_mask, position_ids
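The new make_data_loader helper pairs a plain SequentialSampler with DistributedBatchSampler and passes the global batch size (args.batch_size times the data-parallel world size). The sampler class itself comes from megatron.data_utils.samplers and is not part of this diff; the sketch below only illustrates the behavior the helper appears to rely on, namely that every data-parallel rank walks the same global batches and keeps its own contiguous slice, with drop_last=True discarding an incomplete trailing batch.

```python
import torch

def sketch_distributed_batches(dataset_size, global_batch_size, rank, world_size):
    """Assumed slicing behavior of DistributedBatchSampler (illustration only)."""
    per_rank = global_batch_size // world_size
    sampler = torch.utils.data.SequentialSampler(range(dataset_size))
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == global_batch_size:
            start = rank * per_rank
            yield batch[start:start + per_rank]  # this rank's shard of the batch
            batch = []
    # drop_last=True: a final incomplete global batch is simply discarded

# 8 samples, global batch size 4, two data-parallel ranks:
print(list(sketch_distributed_batches(8, 4, rank=0, world_size=2)))  # [[0, 1], [4, 5]]
print(list(sketch_distributed_batches(8, 4, rank=1, world_size=2)))  # [[2, 3], [6, 7]]
```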
@@ -23,14 +23,12 @@ from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import BertModel
 from megatron.training import pretrain
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
     args = get_args()
@@ -151,26 +149,9 @@ def get_train_val_test_data():
             skip_warmup=(not args.mmap_warmup))
         print_rank_0("> finished creating BERT datasets ...")

-        def make_data_loader_(dataset):
-            if not dataset:
-                return None
-            # Use a simple sampler with distributed batch sampler.
-            sampler = torch.utils.data.SequentialSampler(dataset)
-            batch_sampler = DistributedBatchSampler(
-                sampler=sampler,
-                batch_size=global_batch_size,
-                drop_last=True,
-                rank=data_parallel_rank,
-                world_size=data_parallel_size)
-            # Torch dataloader.
-            return torch.utils.data.DataLoader(dataset,
-                                               batch_sampler=batch_sampler,
-                                               num_workers=args.num_workers,
-                                               pin_memory=True)
-
-        train_data = make_data_loader_(train_ds)
-        valid_data = make_data_loader_(valid_ds)
-        test_data = make_data_loader_(test_ds)
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)

     do_train = train_data is not None and args.train_iters > 0
     do_valid = valid_data is not None and args.eval_iters > 0
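With the helper now imported from megatron.utils, pretrain_bert.py no longer carries its own make_data_loader_ closure. The following is a hypothetical single-process stand-in (world size 1, rank 0, toy datasets, and a made-up batch size) that only shows the resulting call pattern and the do_train/do_valid guards; it is not the shared Megatron helper.

```python
import torch

def make_data_loader_sketch(dataset, batch_size=4, num_workers=0):
    """Single-process stand-in for megatron.utils.make_data_loader."""
    if dataset is None:
        return None
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                  batch_size=batch_size,
                                                  drop_last=True)
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

# Toy tensors standing in for the BERT train/valid/test datasets.
train_ds = torch.utils.data.TensorDataset(torch.arange(32).view(-1, 1))
valid_ds = torch.utils.data.TensorDataset(torch.arange(8).view(-1, 1))
test_ds = None  # e.g. no test split configured

train_data = make_data_loader_sketch(train_ds)
valid_data = make_data_loader_sketch(valid_ds)
test_data = make_data_loader_sketch(test_ds)

train_iters, eval_iters = 100, 10  # made-up stand-ins for args.train_iters / args.eval_iters
do_train = train_data is not None and train_iters > 0
do_valid = valid_data is not None and eval_iters > 0
print(do_train, do_valid)  # True True
```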
@@ -25,10 +25,10 @@ from megatron import get_tokenizer
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.data.gpt2_dataset import GPT2Dataset
-from megatron.data_utils.samplers import DistributedBatchSampler
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses
@@ -121,32 +121,19 @@ def make_gpt2_dataloaders():
     seq_length = args.seq_length
     initial_seed = args.seed

-    # Data parallel arguments.
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
-    num_workers = args.num_workers
-
-    def make_data_loader_(data_path):
-        # Build the dataset.
-        dataset = GPT2Dataset(data_path, input_data_sizes_file,
-                              seq_length, initial_seed)
-        # Use a simple sampler with distributed batch sampler.
-        sampler = torch.utils.data.SequentialSampler(dataset)
-        batch_sampler = DistributedBatchSampler(sampler=sampler,
-                                                batch_size=global_batch_size,
-                                                drop_last=True,
-                                                rank=rank,
-                                                world_size=world_size)
-        # Torch dataloader.
-        return torch.utils.data.DataLoader(dataset,
-                                           batch_sampler=batch_sampler,
-                                           num_workers=num_workers,
-                                           pin_memory=True)
-
-    train = make_data_loader_(os.path.join(args.data_path, 'train'))
-    valid = make_data_loader_(os.path.join(args.data_path, 'valid'))
-    test = make_data_loader_(os.path.join(args.data_path, 'test'))
+    # Build the datasets.
+    def build_dataset_(name):
+        return GPT2Dataset(os.path.join(args.data_path, name),
+                           args.input_data_sizes_file,
+                           args.seq_length, args.seed)
+    train_ds = build_dataset_('train')
+    valid_ds = build_dataset_('valid')
+    test_ds = build_dataset_('test')
+
+    # Dataloaders
+    train = make_data_loader(train_ds)
+    valid = make_data_loader(valid_ds)
+    test = make_data_loader(test_ds)

     args.do_train = False
     args.do_valid = False
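make_gpt2_dataloaders() now folds dataset construction into a small build_dataset_ closure and reuses the shared make_data_loader. The sketch below mirrors only that construction pattern, with a stub class standing in for megatron.data.gpt2_dataset.GPT2Dataset and made-up values for args.data_path, the sizes file, seq_length, and seed.

```python
import os

class StubGPT2Dataset:
    """Stand-in with the constructor shape used in the diff:
    GPT2Dataset(data_prefix, sizes_file, seq_length, seed)."""
    def __init__(self, data_prefix, sizes_file, seq_length, seed):
        self.data_prefix = data_prefix
        self.sizes_file = sizes_file
        self.seq_length = seq_length
        self.seed = seed

data_path = '/path/to/gpt2-data'      # hypothetical args.data_path
input_data_sizes_file = 'sizes.txt'   # default of --input-data-sizes-file
seq_length, seed = 1024, 1234         # hypothetical args.seq_length / args.seed

def build_dataset_(name):
    # Each split lives under <data_path>/<name>, as in the refactored code.
    return StubGPT2Dataset(os.path.join(data_path, name),
                           input_data_sizes_file, seq_length, seed)

train_ds = build_dataset_('train')
valid_ds = build_dataset_('valid')
test_ds = build_dataset_('test')
print(train_ds.data_prefix)  # /path/to/gpt2-data/train
```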