Commit 4558e42f authored by Neel Kant

Implement InverseClozeDataset with IndexedDataset

parent 9350ee08
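In short: build_train_valid_test_datasets() gains an ict_dataset flag, and when it is set each split is built as an InverseClozeDataset instead of a BertDataset. A rough call-site sketch (the argument values below are placeholders for illustration, not taken from this commit):

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-corpus_text_sentence',   # placeholder path prefix
    data_impl='mmap',
    splits_string='949,50,1',
    train_valid_test_num_samples=[1000000, 10000, 1000],
    max_seq_length=512,
    masked_lm_prob=0.15,    # still part of the signature; only BertDataset consumes it
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True,
    ict_dataset=True)       # new flag: build InverseClozeDataset splits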
@@ -27,13 +27,15 @@ from megatron import mpu
 from megatron.data import helpers
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
+from megatron.data.ict_dataset import InverseClozeDataset
 from megatron import print_rank_0


 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
-                                    short_seq_prob, seed, skip_warmup):
+                                    short_seq_prob, seed, skip_warmup,
+                                    ict_dataset=False):

     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
@@ -74,16 +76,21 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         # New doc_idx view.
         indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
         # Build the dataset accordingly.
-        dataset = BertDataset(
+        kwargs = dict(
             name=name,
             indexed_dataset=indexed_dataset,
             data_prefix=data_prefix,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
-            masked_lm_prob=masked_lm_prob,
             max_seq_length=max_seq_length,
             short_seq_prob=short_seq_prob,
-            seed=seed)
+            seed=seed
+        )
+        if ict_dataset:
+            dataset = InverseClozeDataset(**kwargs)
+        else:
+            dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
         # Set the original pointer so dataset remains the main dataset.
         indexed_dataset.set_doc_idx(doc_idx_ptr)
         # Checks.
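New file (path inferred from the import added above): megatron/data/ict_dataset.py. It defines InverseClozeDataset, which serves (sentence, surrounding-block) pairs for the inverse cloze task; each sample is a dict of padded token ids, token types, and pad masks for the input sentence and its context.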
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from .bert_dataset import get_samples_mapping_
class InverseClozeDataset(Dataset):
"""Dataset containing sentences and various 'blocks' for an inverse cloze task."""
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        # short_seq_prob is read again in __getitem__, so keep it on the instance.
        self.short_seq_prob = short_seq_prob
        self.indexed_dataset = indexed_dataset
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Keep a handle to the tokenizer; sentence_tokenize() uses self.tokenizer.
        tokenizer = get_tokenizer()
        self.tokenizer = tokenizer
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx + 1000)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get seq length. Save 2 tokens for beginning and end
target_seq_length = self.max_seq_length - 2
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(5, target_seq_length)
input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
input_tokens, input_token_types, input_pad_mask = input_data
context_tokens, context_token_types, context_pad_mask = context_data
sample = {
'input_text': np.array(input_tokens),
'input_types': np.array(input_token_types),
'input_pad_mask': np.array(input_pad_mask),
'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask)
}
return sample
def get_sentence_split_doc(self, idx):
"""fetch document at index idx and split into sentences"""
document = self.indexed_dataset[idx]
if isinstance(document, dict):
document = document['text']
lines = document.split('\n')
return [line for line in lines if line]
def sentence_tokenize(self, sent, sentence_num=0):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
return tokens, token_types
def concat_and_pad_tokens(self, tokens, token_types):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
token_types = [token_types[0]] + token_types + [token_types[0]]
assert len(tokens) <= self.max_seq_length
num_pad = max(0, self.max_seq_length - len(tokens))
pad_mask = [0] * len(tokens) + [1] * num_pad
tokens += [self.pad_id] * num_pad
token_types += [token_types[0]] * num_pad
return tokens, token_types, pad_mask
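    # Worked example (illustrative values; 101/102/0 are the usual BERT cls/sep/pad ids,
    # not something this commit pins down). With max_seq_length = 8:
    #   concat_and_pad_tokens([5, 6, 7], [0, 0, 0]) ->
    #     tokens      = [101, 5, 6, 7, 102, 0, 0, 0]
    #     token_types = [0, 0, 0, 0, 0, 0, 0, 0]
    #     pad_mask    = [0, 0, 0, 0, 0, 1, 1, 1]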
def get_input_and_context(self, target_seq_length, rng, np_rng):
"""fetches a sentence and its surrounding context"""
num_tries = 0
while num_tries < 20:
num_tries += 1
doc = None
while doc is None:
doc_idx = np_rng.randint(len(self) - 1)
# doc is a list of sentences
doc = self.get_sentence_split_doc(doc_idx)
if not doc:
doc = None
# set up and tokenize the entire selected document
num_sentences = len(doc)
padless_max_len = self.max_seq_length - 2
# select a random sentence from the document as input
# TODO: consider adding multiple input sentences.
input_sentence_idx = rng.randint(0, num_sentences - 1)
tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
if not len(input_tokens) > 0:
continue
context_tokens, context_token_types = [], []
# 10% of the time, the input sentence is left in the context.
# The other 90% of the time, keep it out.
if rng.random() < 0.1:
context_tokens = input_tokens.copy()
context_token_types = input_token_types.copy()
# parameters for examining sentences to add to the context
view_preceding = True
view_radius = 1
while len(context_tokens) < padless_max_len:
# keep adding sentences while the context can accommodate more.
if view_preceding:
examine_idx = input_sentence_idx - view_radius
if examine_idx >= 0:
new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
context_tokens = new_tokens + context_tokens
context_token_types = new_token_types + context_token_types
else:
examine_idx = input_sentence_idx + view_radius
if examine_idx < num_sentences:
new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
context_tokens += new_tokens
context_token_types += new_token_types
view_radius += 1
view_preceding = not view_preceding
if view_radius > num_sentences:
break
# assemble the tokens and token types of the context
context_tokens = context_tokens[:padless_max_len]
context_token_types = context_token_types[:padless_max_len]
if not len(context_tokens) > 0:
continue
# concatenate 'CLS' and 'SEP' tokens and add extra token types
input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
input_tokens, input_token_types)
context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
context_tokens, context_token_types)
return (input_tokens, input_token_types, input_pad_mask), \
(context_tokens, context_token_types, context_pad_mask)
else:
raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
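For orientation, a minimal sketch of how these sample dicts line up with the keys the pretrain script's get_batch() broadcasts. The DataLoader wrapping, the batch size, and the name ict_ds (an already-constructed InverseClozeDataset) are assumptions for illustration, not part of the commit:

from torch.utils.data import DataLoader

loader = DataLoader(ict_ds, batch_size=4)
batch = next(iter(loader))
for key in ['input_text', 'input_types', 'input_pad_mask',
            'context_text', 'context_types', 'context_pad_mask']:
    # default collate stacks the per-sample numpy arrays into
    # [batch_size, max_seq_length] tensors, one per key
    print(key, batch[key].shape)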
@@ -18,43 +18,32 @@
 import torch
 import torch.nn.functional as F

-from configure_data import configure_data
+from megatron import get_args
+from megatron import get_timers
 from megatron import mpu
+from megatron import print_rank_0
+from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
-from megatron.utils import print_rank_0
+from megatron.training import pretrain
+from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses
-from megatron.utils import vocab_size_with_padding
-from megatron.training import run

 num_batches = 0


-def model_provider(args):
+def model_provider():
     """Build the model."""
+    args = get_args()

     print_rank_0('building BERT models ...')

     model = ICTBertModel(
-        num_layers=args.num_layers,
-        vocab_size=args.vocab_size,
-        hidden_size=args.hidden_size,
-        num_attention_heads=args.num_attention_heads,
-        embedding_dropout_prob=args.hidden_dropout,
-        attention_dropout_prob=args.attention_dropout,
-        output_dropout_prob=args.hidden_dropout,
-        max_sequence_length=args.max_position_embeddings,
-        checkpoint_activations=args.checkpoint_activations,
         ict_head_size=128,
-        checkpoint_num_layers=args.checkpoint_num_layers,
-        layernorm_epsilon=args.layernorm_epsilon,
-        num_tokentypes=args.tokentype_size,
-        parallel_output=True,
-        apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
-        attention_softmax_in_fp32=args.attention_softmax_in_fp32)
+        num_tokentypes=2,
+        parallel_output=True)

     return model


-def get_batch(data_iterator, timers):
+def get_batch(data_iterator):

     # Items and their type.
     keys = ['input_text', 'input_types', 'input_pad_mask',
@@ -62,13 +51,10 @@ def get_batch(data_iterator, timers):
     datatype = torch.int64

     # Broadcast data.
-    timers('data loader').start()
     if data_iterator is None:
         data = None
     else:
         data = next(data_iterator)
-    timers('data loader').stop()
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
@@ -83,17 +69,17 @@ def get_batch(data_iterator, timers):
            context_tokens, context_types, context_pad_mask


-def forward_step(data_iterator, model, args, timers):
+def forward_step(data_iterator, model):
     """Forward step."""
+    timers = get_timers()

     # Get the batch.
     timers('batch generator').start()
     input_tokens, input_types, input_pad_mask,\
-    context_tokens, context_types, context_pad_mask = get_batch(data_iterator, timers)
+    context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    # TODO: important to make sure that everything, including padding mask is as expected here.
     retrieval_scores = model(input_tokens, input_pad_mask, input_types,
                              context_tokens, context_pad_mask, context_types).float()
@@ -112,50 +98,71 @@ def forward_step(data_iterator, model, args, timers):
             'top5_acc': reduced_losses[2]}


-def get_train_val_test_data(args):
+def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+    args = get_args()

     (train_data, val_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        if (args.data_loader == 'raw'
-            or args.data_loader == 'lazy'
-            or args.data_loader == 'tfrecords'):
-            data_config = configure_data()
-            ds_type = 'BERT_ict'
-            data_config.set_defaults(data_set_type=ds_type, transpose=False)
-            (train_data, val_data, test_data), tokenizer = data_config.apply(args)
-            num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor([num_tokens,
-                                                  tokenizer.num_type_tokens,
-                                                  int(args.do_train),
-                                                  int(args.do_valid),
-                                                  int(args.do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
-            exit(1)
+        print_rank_0('> building train, validation, and test datasets '
+                     'for BERT ...')
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=(not args.mmap_warmup),
+            ict_dataset=True)
+        print_rank_0("> finished creating BERT ICT datasets ...")
+
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
     else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
+        flags = torch.cuda.LongTensor([0, 0, 0])

     # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
+    torch.distributed.broadcast(flags,
                                 mpu.get_model_parallel_src_rank(),
                                 group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
-    args.do_train = token_counts[2].item()
-    args.do_valid = token_counts[3].item()
-    args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens
+    args.do_train = flags[0].item()
+    args.do_valid = flags[1].item()
+    args.do_test = flags[2].item()

     return train_data, val_data, test_data


 if __name__ == "__main__":

-    run('Pretrain ICT BERT model', get_train_val_test_data,
-        model_provider, forward_step)
+    pretrain(get_train_val_test_data, model_provider, forward_step,
+             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
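The loss computation in forward_step is elided from this diff. As a generic illustration only (not the commit's code), the standard ICT in-batch objective over a [batch, batch] score matrix, which also yields the top-1/top-5 accuracies reported in the losses dict, looks roughly like this; the square score-matrix shape is an assumption:

import torch
import torch.nn.functional as F

def ict_in_batch_objective(retrieval_scores):
    # retrieval_scores[i, j]: similarity of input sentence i with context block j;
    # the matching context for input i is assumed to sit on the diagonal (in-batch negatives).
    batch_size = retrieval_scores.shape[0]
    labels = torch.arange(batch_size, device=retrieval_scores.device)
    loss = F.cross_entropy(retrieval_scores, labels)
    top1_acc = (retrieval_scores.argmax(dim=1) == labels).float().mean()
    topk = retrieval_scores.topk(min(5, batch_size), dim=1).indices
    top5_acc = (topk == labels.unsqueeze(1)).any(dim=1).float().mean()
    return loss, top1_acc, top5_acc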