Commit 46b2addd authored by Neel Kant

Modify InverseClozeDataset and its usage to use blocks

parent 64aef3a5
@@ -42,6 +42,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                            data_impl,
                                            skip_warmup)
+    if ict_dataset:
+        titles_dataset = get_indexed_dataset_(data_prefix + '-titles',
+                                              data_impl,
+                                              skip_warmup)
+
     # Get start and end indices of train/valid/test into doc-idx
     # Note that doc-idx is designed to be num-docs + 1 so we can
     # easily iterate over it.
@@ -88,7 +93,10 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             )
             if ict_dataset:
-                dataset = InverseClozeDataset(**kwargs)
+                titles_idx_ptr = titles_dataset.get_doc_idx()
+                titles_dataset.set_doc_idx(titles_idx_ptr[start_index:end_index])
+                dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
+                titles_dataset.set_doc_idx(titles_idx_ptr)
             else:
                 dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
             # Set the original pointer so dataset remains the main dataset.
...
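The slice-and-restore around the InverseClozeDataset construction scopes the titles dataset to the documents of the current split without copying it. A minimal sketch of that pattern, with a hypothetical make_dataset callback standing in for the InverseClozeDataset constructor and assuming get_doc_idx/set_doc_idx expose and replace the underlying document index array:

def build_split_dataset(titles_dataset, start_index, end_index, make_dataset):
    # Save the full document index so it can be restored afterwards.
    titles_idx_ptr = titles_dataset.get_doc_idx()
    try:
        # Temporarily restrict the titles dataset to this split's documents.
        titles_dataset.set_doc_idx(titles_idx_ptr[start_index:end_index])
        dataset = make_dataset(titles_dataset=titles_dataset)
    finally:
        # Restore the original pointer so later splits see the whole dataset.
        titles_dataset.set_doc_idx(titles_idx_ptr)
    return dataset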
@@ -11,17 +11,28 @@ from megatron import print_rank_0
 from megatron import mpu
 from megatron.data import helpers
 
 
 class InverseClozeDataset(Dataset):
-    """Dataset containing sentences and various 'blocks' for an inverse cloze task."""
+    """Dataset containing sentences and their blocks for an inverse cloze task."""
 
-    def __init__(self, name, indexed_dataset, data_prefix,
+    def __init__(self, name, context_dataset, titles_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
                  short_seq_prob, seed):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
-        self.indexed_dataset = indexed_dataset
+        self.context_dataset = context_dataset
+        self.titles_dataset = titles_dataset
         self.short_seq_prob = short_seq_prob
+        self.rng = random.Random(self.seed)
+
+        self.samples_mapping = get_samples_mapping(self.context_dataset,
+                                                   self.titles_dataset,
+                                                   data_prefix,
+                                                   num_epochs,
+                                                   max_num_samples,
+                                                   self.max_seq_length,
+                                                   self.seed,
+                                                   self.name)
 
         tokenizer = get_tokenizer()
         self.vocab_id_list = list(tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = tokenizer.inv_vocab
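Reading ahead to __getitem__, each row of samples_mapping is unpacked as a (start_index, end_index, doc_index) triple over sentence records of the context dataset; the exact layout comes from helpers.build_blocks_mapping, so treat it as an assumption. A toy illustration of how one row would be consumed:

import numpy as np

# Assumed row layout: (start sentence index, end sentence index, document index).
samples_mapping = np.array([[0, 3, 0],   # sample 0: sentences 0..2 of doc 0
                            [3, 7, 1]])  # sample 1: sentences 3..6 of doc 1

start_index, end_index, doc_index = samples_mapping[1]
sentence_ids = list(range(start_index, end_index))  # [3, 4, 5, 6]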
@@ -31,21 +42,24 @@ class InverseClozeDataset(Dataset):
         self.pad_id = tokenizer.pad
 
     def __len__(self):
-        return self.indexed_dataset.doc_idx.shape[0]
+        return self.samples_mapping.shape[0]
 
     def __getitem__(self, idx):
-        # get rng state corresponding to index (allows deterministic random pair)
-        rng = random.Random(idx + 20000 + self.seed)
-        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
-
-        # get seq length. Save 2 tokens for beginning and end
-        target_seq_length = self.max_seq_length - 2
-        if rng.random() < self.short_seq_prob:
-            target_seq_length = rng.randint(5, target_seq_length)
-
-        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
-        input_tokens, input_token_types, input_pad_mask = input_data
-        context_tokens, context_token_types, context_pad_mask = context_data
+        start_index, end_index, _ = self.samples_mapping[idx]
+        context = [self.context_dataset[i] for i in range(start_index, end_index)]
+        assert len(context) > 1
+        title = self.titles_dataset[idx]
+        assert sum(len(c) for c in context) + len(title) <= self.max_seq_length - 3
+
+        rand_sent_idx = self.rng.randint(0, len(context) - 1)
+        if self.rng.random() < 0.1:
+            input = list(context[rand_sent_idx])
+        else:
+            input = context.pop(rand_sent_idx)
+
+        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
+        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
 
         sample = {
             'input_text': np.array(input_tokens),
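The new __getitem__ builds an ICT pair at the block level: one sentence of the block becomes the query ("input"), and with 10% probability it is also left in the block so the model sometimes sees the answer in its context. A self-contained sketch of that sampling rule, with illustrative names rather than the dataset's API:

import random

def make_ict_pair(block, rng, keep_prob=0.1):
    """Pick one sentence of a tokenized block as the query; with probability
    keep_prob leave it in the block, otherwise remove it (the 10% rule above)."""
    block = [list(sent) for sent in block]       # copy so the caller's block is untouched
    query_idx = rng.randint(0, len(block) - 1)
    if rng.random() < keep_prob:
        query = list(block[query_idx])           # keep the sentence in its context
    else:
        query = block.pop(query_idx)             # drop it from the context
    context = [tok for sent in block for tok in sent]  # flatten the remaining sentences
    return query, context

rng = random.Random(1234)
query, context = make_ict_pair([[5, 6, 7], [8, 9], [10, 11, 12]], rng)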
@@ -58,19 +72,11 @@ class InverseClozeDataset(Dataset):
 
         return sample
 
-    def get_sentence_split_doc(self, idx):
-        """fetch document at index idx and split into sentences"""
-        doc_start = self.indexed_dataset.doc_idx[idx]
-        doc_end = self.indexed_dataset.doc_idx[idx + 1]
-        doc_sentences_array = self.indexed_dataset[doc_start:doc_end]
-        doc_sentences = [list(arr) for arr in doc_sentences_array]
-        return doc_sentences
-
-    def concat_and_pad_tokens(self, tokens):
+    def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
+        if title is not None:
+            tokens += title + [self.sep_id]
         assert len(tokens) <= self.max_seq_length
 
         num_pad = self.max_seq_length - len(tokens)
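With the optional title argument, the padded sequence is laid out as [CLS] tokens [SEP] title [SEP] followed by padding. A standalone illustration with made-up special-token ids (the real ones come from the Megatron tokenizer) and an assumed pad-mask convention of 1 for real tokens:

# Hypothetical special-token ids, for illustration only.
CLS, SEP, PAD = 101, 102, 0

def concat_and_pad(tokens, title=None, max_seq_length=12):
    tokens = [CLS] + list(tokens) + [SEP]
    if title is not None:
        tokens += list(title) + [SEP]
    assert len(tokens) <= max_seq_length
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad  # assumed convention: 1 = real token
    return tokens + [PAD] * num_pad, pad_mask

tokens, pad_mask = concat_and_pad([5, 6, 7], title=[8, 9])
# tokens   -> [101, 5, 6, 7, 102, 8, 9, 102, 0, 0, 0, 0]
# pad_mask -> [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]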
@@ -79,66 +85,83 @@ class InverseClozeDataset(Dataset):
         token_types = [0] * self.max_seq_length
 
         return tokens, token_types, pad_mask
 
-    def get_input_and_context(self, target_seq_length, rng, np_rng):
-        """fetches a sentence and its surrounding context"""
-        num_tries = 0
-        while num_tries < 20:
-            num_tries += 1
-            doc = None
-            while doc is None:
-                doc_idx = np_rng.randint(len(self) - 1)
-                # doc is a list of sentences
-                doc = self.get_sentence_split_doc(doc_idx)
-                if not doc:
-                    doc = None
-
-            num_sentences = len(doc)
-            padless_max_len = self.max_seq_length - 2
-
-            # select a random sentence from the document as input
-            # TODO: consider adding multiple input sentences.
-            input_sentence_idx = rng.randint(0, num_sentences - 1)
-            input_tokens = doc[input_sentence_idx][:target_seq_length]
-            if not len(input_tokens) > 0:
-                continue
-
-            context_tokens = []
-            # 10% of the time, the input sentence is left in the context.
-            # The other 90% of the time, keep it out.
-            if rng.random() < 0.1:
-                context_tokens = input_tokens.copy()
-
-            view_preceding = True
-            view_radius = 1
-            while len(context_tokens) < padless_max_len:
-                # keep adding sentences while the context can accommodate more.
-                if view_preceding:
-                    examine_idx = input_sentence_idx - view_radius
-                    if examine_idx >= 0:
-                        new_tokens = doc[examine_idx]
-                        context_tokens = new_tokens + context_tokens
-                else:
-                    examine_idx = input_sentence_idx + view_radius
-                    if examine_idx < num_sentences:
-                        new_tokens = doc[examine_idx]
-                        context_tokens += new_tokens
-                    view_radius += 1
-                view_preceding = not view_preceding
-                if view_radius > num_sentences:
-                    break
-
-            # assemble the tokens and token types of the context
-            context_tokens = context_tokens[:padless_max_len]
-            if not len(context_tokens) > 0:
-                continue
-
-            # concatenate 'CLS' and 'SEP' tokens and add extra token types
-            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens)
-            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens)
-
-            return (input_tokens, input_token_types, input_pad_mask), \
-                   (context_tokens, context_token_types, context_pad_mask)
-        else:
-            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
+
+def get_samples_mapping(context_dataset,
+                        titles_dataset,
+                        data_prefix,
+                        num_epochs,
+                        max_num_samples,
+                        max_seq_length,
+                        seed,
+                        name):
+    if not num_epochs:
+        if not max_num_samples:
+            raise ValueError("Need to specify either max_num_samples "
+                             "or num_epochs")
+        num_epochs = np.iinfo(np.int32).max - 1
+    if not max_num_samples:
+        max_num_samples = np.iinfo(np.int64).max - 1
+
+    # Filename of the index mapping
+    indexmap_filename = data_prefix
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}msl'.format(max_seq_length)
+    indexmap_filename += '_{}s'.format(seed)
+    indexmap_filename += '.npy'
+
+    # Build the indexed mapping if it does not exist.
+    if torch.distributed.get_rank() == 0 and \
+       not os.path.isfile(indexmap_filename):
+        print(' > WARNING: could not find index map file {}, building '
+              'the indices on rank 0 ...'.format(indexmap_filename))
+
+        # Make sure the types match the helpers input types.
+        assert context_dataset.doc_idx.dtype == np.int64
+        assert context_dataset.sizes.dtype == np.int32
+
+        # Build samples mapping
+        verbose = torch.distributed.get_rank() == 0
+        start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
+        samples_mapping = helpers.build_blocks_mapping(
+            context_dataset.doc_idx,
+            context_dataset.sizes,
+            titles_dataset.sizes,
+            num_epochs,
+            max_num_samples,
+            max_seq_length - 3,  # account for added tokens
+            seed,
+            verbose)
+        print_rank_0(' > done building samples index mapping')
+        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
+        # Make sure all the ranks have built the mapping
+        print_rank_0(' > elapsed time to build and save samples mapping '
+                     '(seconds): {:4f}'.format(
+                         time.time() - start_time))
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())
+
+    # Load indexed dataset.
+    print_rank_0(' > loading indexed mapping from {}'.format(
+        indexmap_filename))
+    start_time = time.time()
+    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+        time.time() - start_time))
+    print_rank_0('    total number of samples: {}'.format(
+        samples_mapping.shape[0]))
+
+    return samples_mapping
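get_samples_mapping follows a build-once, cache-on-disk pattern: rank 0 builds the block mapping and saves it as a .npy file, an all_reduce over the data-parallel group stands in for a barrier, and every rank then loads the cached array. A minimal, hedged sketch of the same pattern outside Megatron, assuming torch.distributed is already initialized, a CUDA device is available, and cache_path ends in '.npy'; build_fn is a stand-in for helpers.build_blocks_mapping:

import os
import numpy as np
import torch

def build_or_load_mapping(cache_path, build_fn, group=None):
    """Rank 0 builds and saves the mapping; every rank then loads it from disk."""
    if torch.distributed.get_rank() == 0 and not os.path.isfile(cache_path):
        np.save(cache_path, build_fn(), allow_pickle=True)
    # Poor man's barrier: an all_reduce every rank joins, mirroring the comment
    # in the diff about the nccl barrier assuming device_index == rank.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=group)
    return np.load(cache_path, allow_pickle=True)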