Commit 2e38461b authored by Mohammad

data loading for BERT and GPT cleaned up

parent eb74fa34
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 dataset."""
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
class GPT2Dataset(Dataset):
def __init__(self, data_path, sizes_filename, seq_length,
initial_seed, max_epochs=100):
# Input parameters.
self.data_path = data_path
self.sizes_filename = sizes_filename
self.seq_length = seq_length
self.initial_seed = initial_seed
self.max_epochs = max_epochs
# Shard stuff.
# Dictionary from shard name to its size (number of elements).
self.master_shard_size_dict = None
# Dictionary from shard name to modified size so it is
# divisible by self.seq_length.
self.shard_size_dict = None
# Long array (self.max_epochs * num-shards) populated
# randomly with shard names.
self.shards_name = None
# Start index of the data for a shard.
self.shards_start_index = None
self.build_shard_mappings_()
self.data_length = self.shards_start_index[-1]
# Data.
self.shards_data = [None]*self.shards_name.size
self.shards_sample_index = [None]*self.shards_name.size
def __len__(self):
return self.data_length
def __getitem__(self, idx):
# Find which shard we need.
shard_index = np.searchsorted(self.shards_start_index,
idx, side='right') - 1
# data index in the shard.
data_idx = idx - self.shards_start_index[shard_index]
# Load the shard if it is not in memory.
if self.shards_data[shard_index] is None:
print('global rank {} is building data for shard index {} ...'.
format(torch.distributed.get_rank(), shard_index))
self.build_dataset_(shard_index)
#assert self.shards_data[shard_index] is not None
# Start index.
start_index = self.shards_sample_index[shard_index][data_idx]
# Add one for label shift.
end_index = start_index + self.seq_length + 1
data = self.shards_data[shard_index][start_index:end_index]
return {'text': np.array(data, dtype=np.int64)}
def build_dataset_(self, shard_index):
# Garbage collect so we don't use a lot of memory.
# Leave the last one in case other threads have not caught up yet.
#for i in range(shard_index - 1):
for i in range(shard_index):
self.shards_data[i] = None
self.shards_sample_index[i] = None
# Read the shard.
filename = os.path.join(self.data_path, self.shards_name[shard_index])
print('loading {}'.format(filename))
data = np.load(filename, allow_pickle=True)
# Shuffle the data
rng = np.random.RandomState(self.initial_seed + shard_index)
rng.shuffle(data)
# Flatten.
data = np.hstack(data)
size = (data.shape[0] - 1) // self.seq_length
last_index = size * self.seq_length + 1
data = data[0:last_index]
self.shards_data[shard_index] = data
indices = np.arange(size) * self.seq_length
rng.shuffle(indices)
self.shards_sample_index[shard_index] = indices
def build_shard_mappings_(self):
# Load the sizes file.
sizes_filename = os.path.join(self.data_path, self.sizes_filename)
if torch.distributed.get_rank() == 0:
print(' > loading sizes from {}'.format(sizes_filename))
with open(sizes_filename, 'r') as f:
self.master_shard_size_dict = json.load(f)
if torch.distributed.get_rank() == 0:
print(' found {} shards'.format(len(self.master_shard_size_dict)))
# Adjust sizes to be a multiple of seq_length.
self.shard_size_dict = self.master_shard_size_dict.copy()
total_samples = 0
for shard in self.shard_size_dict:
size = self.shard_size_dict[shard]
size = ((size - 1) // self.seq_length) * self.seq_length
total_samples += size // self.seq_length
self.shard_size_dict[shard] = size
if torch.distributed.get_rank() == 0:
print(' found {} samples in the dataset'.format(total_samples))
# Build a list of shards.
shards_ = np.sort(np.array(list(self.shard_size_dict.keys())))
rng = np.random.RandomState(self.initial_seed)
self.shards_name = np.copy(shards_)
rng.shuffle(self.shards_name)
for i in range(1, self.max_epochs):
shards_c = np.copy(shards_)
rng.shuffle(shards_c)
self.shards_name = np.append(self.shards_name, shards_c)
# Build the global indexing.
self.shards_start_index = np.zeros(self.shards_name.size, dtype=np.int64)
self.shards_start_index[0] = 0
for i in range(1, self.shards_name.size):
shard = str(self.shards_name[i-1])
size = self.shard_size_dict[shard]
self.shards_start_index[i] = self.shards_start_index[i-1] + \
size // self.seq_length
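
As a quick orientation for the new dataset class above, here is a hedged usage sketch. The directory layout, file names, and sizes-file format are assumptions for illustration only (the class just needs a pickled numpy shard for every name listed in the sizes JSON), and note that it calls torch.distributed.get_rank(), so a process group must already be initialized:

# Hypothetical usage of GPT2Dataset; all paths and file names below are made up.
# Assumed layout:
#   /data/gpt2/train/shard_000.npy, shard_001.npy, ...  (object arrays of token-id sequences)
#   /data/gpt2/train/sizes.json -> {"shard_000.npy": 1234567, "shard_001.npy": ...}
import torch

# The class prints via torch.distributed.get_rank(), so a process group
# must exist even for single-process experimentation.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend='gloo',
                                         init_method='tcp://127.0.0.1:23456',
                                         rank=0, world_size=1)

dataset = GPT2Dataset(data_path='/data/gpt2/train',
                      sizes_filename='sizes.json',
                      seq_length=1024,
                      initial_seed=1234)
print(len(dataset))          # total number of seq_length samples across max_epochs
sample = dataset[0]          # lazily loads the first shard on first access
print(sample['text'].shape)  # (seq_length + 1,): one extra token for the label shift
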
@@ -37,11 +37,12 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import make_data_loader
 from megatron.utils import report_memory


-def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
-             extra_args_provider=None, args_defaults={}):
+def pretrain(train_valid_test_dataset_provider, model_provider,
+             forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.

     This function will run the followings in the order provided:
@@ -51,9 +52,9 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
         4) train the modle using the forward_step_func.

     Arguments:
-        train_val_test_data_provider: a function that builds datasets
-            and returns `train, val, test` dataloaders.
+        train_valid_test_dataset_provider: a function that takes the size of
+            train/valid/test dataset and returns `train, valid, test` datasets.
         model_provider: a function that returns a vanilla version of the
             model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
         forward_step_func: a function that takes a `data iterator` and `model`,
             and returns a `loss` scalar with a dictionary with key:values being
@@ -78,22 +79,15 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     timers('model and optimizer').stop()

     # Data stuff.
-    timers('train/valid/test dataset').start()
-    train_data, val_data, test_data = train_val_test_data_provider()
-    timers('train/valid/test dataset').stop()
-
-    # Train, validation, and test data.
-    timers('train/valid/test dataloader').start()
-    train_data_iterator, val_data_iterator, \
-        test_data_iterator = get_train_val_test_data_iterators(train_data,
-                                                               val_data,
-                                                               test_data)
-    timers('train/valid/test dataloader').stop()
+    timers('train/valid/test data iterators').start()
+    train_data_iterator, valid_data_iterator, test_data_iterator \
+        = build_train_valid_test_data_iterators(
+            train_valid_test_dataset_provider)
+    timers('train/valid/test data iterators').stop()

     # Print setup timing.
     print_rank_0('done with setups ...')
-    timers.log(['model and optimizer', 'train/valid/test dataset',
-                'train/valid/test dataloader'])
+    timers.log(['model and optimizer', 'train/valid/test data iterators'])
     print_rank_0('training ...')

     iteration = 0
@@ -101,13 +95,13 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     if args.do_train:
         iteration, _ = train(forward_step_func,
                              model, optimizer, lr_scheduler,
-                             train_data_iterator, val_data_iterator)
+                             train_data_iterator, valid_data_iterator)

     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
-                                   val_data_iterator, model,
+                                   valid_data_iterator, model,
                                    iteration, False)

     if args.save and iteration != 0:
@@ -152,8 +146,7 @@ def get_model(model_provider_func):
         return model

     raise NotImplementedError('Unknown DDP implementation specified: {}. '
                               'Exiting.'.format(args.DDP_impl))
-    sys.exit()


 def get_optimizer(model):
@@ -352,7 +345,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator):
+          train_data_iterator, valid_data_iterator):
     """Train the model function."""
     args = get_args()
     timers = get_timers()
@@ -403,7 +396,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
-                                       val_data_iterator, model,
+                                       valid_data_iterator, model,
                                        iteration, False)

         if args.exit_interval and iteration % args.exit_interval == 0:
@@ -472,37 +465,87 @@ def evaluate_and_print_results(prefix, forward_step_func,
     print_rank_0('-' * length)


-def get_train_val_test_data_iterators(train_data, val_data, test_data):
-    """Build train/validation/test iterators"""
+def build_train_valid_test_data_iterators(
+        build_train_valid_test_datasets_provider):
+    """XXX"""
     args = get_args()

+    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
+
+    print_rank_0('> building train, validation, and test datasets ...')
+    # Data loader only on rank 0 of each model parallel group.
+    if mpu.get_model_parallel_rank() == 0:
+        # Rank, size, and global batch size.
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        global_batch_size = args.batch_size * data_parallel_size
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('   train: {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('   validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('   test: {}'.format(train_val_test_num_samples[2]))
+        # Build the datasets.
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+            train_val_test_num_samples)
+        # Build dataloders.
+        train_dataloader = make_data_loader(train_ds)
+        valid_dataloader = make_data_loader(valid_ds)
+        test_dataloader = make_data_loader(test_ds)
+        # Flags to know if we need to do training/validation/testing.
+        do_train = train_dataloader is not None and args.train_iters > 0
+        do_valid = valid_dataloader is not None and args.eval_iters > 0
+        do_test = test_dataloader is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
+    else:
+        flags = torch.cuda.LongTensor([0, 0, 0])
+
+    # Broadcast num tokens.
+    torch.distributed.broadcast(flags,
+                                mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()
+
     # Shift the start iterations.
-    if train_data is not None:
-        train_data.batch_sampler.start_iter = args.iteration % \
-            len(train_data)
+    if train_dataloader is not None:
+        train_dataloader.batch_sampler.start_iter = args.iteration % \
+            len(train_dataloader)
         print_rank_0('setting training data start iteration to {}'.
-                     format(train_data.batch_sampler.start_iter))
+                     format(train_dataloader.batch_sampler.start_iter))

-    if val_data is not None:
+    if valid_dataloader is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
             args.eval_iters
-        val_data.batch_sampler.start_iter = start_iter_val % \
-            len(val_data)
+        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
+            len(valid_dataloader)
         print_rank_0('setting validation data start iteration to {}'.
-                     format(val_data.batch_sampler.start_iter))
+                     format(valid_dataloader.batch_sampler.start_iter))

-    if train_data is not None:
-        train_data_iterator = iter(train_data)
+    # Build iterators.
+    if train_dataloader is not None:
+        train_data_iterator = iter(train_dataloader)
     else:
         train_data_iterator = None

-    if val_data is not None:
-        val_data_iterator = iter(val_data)
+    if valid_dataloader is not None:
+        valid_data_iterator = iter(valid_dataloader)
     else:
-        val_data_iterator = None
+        valid_data_iterator = None

-    if test_data is not None:
-        test_data_iterator = iter(test_data)
+    if test_dataloader is not None:
+        test_data_iterator = iter(test_dataloader)
     else:
         test_data_iterator = None

-    return train_data_iterator, val_data_iterator, test_data_iterator
+    return train_data_iterator, valid_data_iterator, test_data_iterator
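
For readers wiring a new task into the refactored pretrain(), the provider only has to accept the list of minimum sample counts and return three map-style datasets whose items look like the GPT2 samples. The toy RandomTokenDataset below is a hypothetical stand-in used purely to illustrate the contract; it is not part of this commit:

# Minimal sketch of the provider contract expected by pretrain() /
# build_train_valid_test_data_iterators(); everything named here is illustrative.
import numpy as np
from torch.utils.data import Dataset


class RandomTokenDataset(Dataset):
    """Toy dataset yielding random token sequences shaped like GPT2 samples."""

    def __init__(self, num_samples, seq_length, vocab_size=50257, seed=1234):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Deterministic per-sample tokens so repeated reads agree.
        rng = np.random.RandomState(self.seed + idx)
        tokens = rng.randint(0, self.vocab_size, size=self.seq_length + 1)
        return {'text': tokens.astype(np.int64)}


def toy_train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets from the requested minimum sizes."""
    def _make(num_samples):
        return RandomTokenDataset(num_samples, seq_length=1024)
    train_n, valid_n, test_n = train_val_test_num_samples
    return _make(train_n), _make(valid_n), _make(test_n)

A real provider, like the BERT and GPT2 ones in the diffs below, builds indexed datasets from the command-line args; only the returned sample shape has to match what the task's get_batch() expects.
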
@@ -25,13 +25,11 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import BertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building BERT model ...')
@@ -44,6 +42,7 @@ def model_provider():
 def get_batch(data_iterator):
+    """Build the batch."""

     # Items and their type.
     keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
@@ -96,70 +95,28 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}


-def get_train_val_test_data():
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('   train: {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('   validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('   test: {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
-        print_rank_0("> finished creating BERT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating BERT datasets ...")

-    return train_data, valid_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":

-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -15,8 +15,6 @@

 """Pretrain GPT2"""

-import os
-
 import torch

 from megatron import get_args
@@ -28,13 +26,11 @@ from megatron.data.gpt2_dataset import build_train_valid_test_datasets
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses


 def model_provider():
     """Build the model."""
-    args = get_args()

     print_rank_0('building GPT2 model ...')

     model = GPT2Model(num_tokentypes=0, parallel_output=True)
@@ -98,68 +94,26 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_loss[0]}


-def get_train_val_test_data():
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid, and test datasets."""
     args = get_args()

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for GPT2 ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('   train: {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('   validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('   test: {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            seq_length=args.seq_length,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup))
-        print_rank_0("> finished creating GPT2 datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for GPT2 ...')
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        seq_length=args.seq_length,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup))
+    print_rank_0("> finished creating GPT2 datasets ...")

-    return train_data, valid_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":

-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
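
Both pretraining scripts now receive train_val_test_num_samples computed once in build_train_valid_test_data_iterators(). As a hedged illustration of that arithmetic, with made-up settings:

# Worked example of the target-size computation in
# build_train_valid_test_data_iterators(); all numbers below are invented.
batch_size = 8              # per data-parallel rank (args.batch_size)
data_parallel_size = 4      # mpu.get_data_parallel_world_size()
train_iters = 100000        # args.train_iters
eval_interval = 1000        # args.eval_interval
eval_iters = 100            # args.eval_iters

global_batch_size = batch_size * data_parallel_size            # 32
train_samples = train_iters * global_batch_size                # 3,200,000
# One eval pass every eval_interval iterations, plus the final eval.
valid_samples = (train_iters // eval_interval + 1) * eval_iters * global_batch_size  # 323,200
test_samples = eval_iters * global_batch_size                  # 3,200

train_val_test_num_samples = [train_samples, valid_samples, test_samples]
print(train_val_test_num_samples)   # [3200000, 323200, 3200]
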
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2"""
import os
import torch
from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.old_gpt2_dataset import GPT2Dataset
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
args.fp16)
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch generator').stop()
# Forward model.
output = model(tokens, position_ids, attention_mask)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
reduced_loss = reduce_losses([loss])
return loss, {'lm loss': reduced_loss[0]}
def make_gpt2_dataloaders():
"""Build gpt2 dataloders."""
args = get_args()
# Input parameters.
input_data_sizes_file = args.input_data_sizes_file
seq_length = args.seq_length
initial_seed = args.seed
# Build the datasets.
def _build_dataset(name):
return GPT2Dataset(os.path.join(args.data_path, name),
args.input_data_sizes_file,
args.seq_length, args.seed)
train_ds = _build_dataset('train')
valid_ds = _build_dataset('valid')
test_ds = _build_dataset('test')
# Dataloaders
train = make_data_loader(train_ds)
valid = make_data_loader(valid_ds)
test = make_data_loader(test_ds)
args.do_train = False
args.do_valid = False
args.do_test = False
if train is not None:
args.do_train = True
if valid is not None:
args.do_valid = True
if test is not None:
args.do_test = True
return (train, valid, test)
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
(train_data, val_data, test_data) = make_gpt2_dataloaders()
flags = torch.cuda.LongTensor([int(args.do_train),
int(args.do_valid),
int(args.do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast the flags.
torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
return train_data, val_data, test_data
if __name__ == "__main__":
pretrain(get_train_val_test_data, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})