"src/vscode:/vscode.git/clone" did not exist on "8d81564b27956dbabeeac833139aab27e60e379d"
Commit 2d98cfbf authored by Neel Kant

Merge staging-realm into realm-mlm

parents 8b1da95a 4abd7ce2
# ===========
# base images
# ===========
FROM nvcr.io/nvidia/pytorch:19.09-py3
# ===============
# system packages
# ===============
RUN apt-get update && apt-get install -y \
bash-completion \
emacs \
git \
graphviz \
htop \
libopenexr-dev \
rsync \
wget \
&& rm -rf /var/lib/apt/lists/*
# ============
# pip packages
# ============
RUN pip install --upgrade pip && \
pip install --upgrade setuptools
COPY requirements.txt /tmp/
RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt
boto3
google-cloud-language
inflect
nltk
numpy
pandas
requests
sentencepiece
tensorflow
tqdm
from collections import defaultdict
import pickle
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()

    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    hash_data = defaultdict(list)
    hash_matrix = np.random.rand(128, 1024)
    all_input_tokens = []
    all_input_logits = []
    all_block_tokens = []
    all_block_logits = []

    while True:
        try:
            input_tokens, input_types, input_pad_mask, \
                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
        except StopIteration:
            break

        input_logits, block_logits, _ = model.module.module.forward(
            input_tokens, input_types, input_pad_mask,
            block_tokens, block_pad_mask, block_token_types, return_logits=True)

        # Random-projection hashing: project the block logits, concatenate the
        # positive and negated projections, and use the argmax as the bucket id.
        # The random matrix is moved to the logits' device and dtype for matmul.
        projection = torch.matmul(
            block_logits,
            torch.from_numpy(hash_matrix).to(dtype=block_logits.dtype,
                                             device=block_logits.device))
        block_hash_full = torch.cat((projection, -projection), dim=1)
        block_hashes = torch.argmax(block_hash_full, dim=1)
        for block_hash, idx in zip(block_hashes, block_indices):
            hash_data[int(block_hash)].append(int(idx))

        all_input_tokens.append(input_tokens.detach().cpu().numpy())
        all_input_logits.append(input_logits.detach().cpu().numpy())
        all_block_tokens.append(block_tokens.detach().cpu().numpy())
        all_block_logits.append(block_logits.detach().cpu().numpy())

    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
    all_block_logits = np.array(all_block_logits).reshape(-1, 128)
    np.save('input_tokens.npy', all_input_tokens)
    np.save('input_logits.npy', all_input_logits)
    np.save('block_tokens.npy', all_block_tokens)
    np.save('block_logits.npy', all_block_logits)

    for bucket, indices in hash_data.items():
        hash_data[bucket] = np.array(indices)
    hash_data['matrix'] = hash_matrix
    with open('hash_data.pkl', 'wb') as hash_file:
        pickle.dump(hash_data, hash_file)
def load_checkpoint():
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def get_dataset():
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    kwargs = dict(
        name='full',
        context_dataset=block_dataset,
        titles_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1
    )
    dataset = InverseClozeDataset(**kwargs)
    return dataset
def get_dataloader(dataset):
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
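The script above builds a simple locality-sensitive hash over the block embeddings: each block's 128-dimensional logits are projected by a fixed random matrix and assigned to the bucket given by the argmax over the signed projections. As a minimal illustration (not part of this commit), the saved artifacts could later be queried like this; the file names come from main() above, while hash_bucket and candidate_blocks_for are hypothetical helpers:

# Illustrative sketch only: look up candidate blocks from the buckets saved above.
import pickle

import numpy as np


def hash_bucket(logits, hash_matrix):
    # Same scheme as main(): project, concatenate positive and negated
    # projections, and take the argmax as the bucket id.
    projection = logits @ hash_matrix
    full = np.concatenate((projection, -projection), axis=-1)
    return int(np.argmax(full))


def candidate_blocks_for(query_logits):
    # Returns indices of blocks that fell into the same bucket as the query.
    with open('hash_data.pkl', 'rb') as f:
        hash_data = pickle.load(f)
    bucket = hash_bucket(query_logits, hash_data['matrix'])
    return hash_data.get(bucket, np.array([], dtype=np.int64))

Only the blocks returned here would then need to be scored exactly against the query embedding, which is the point of bucketing in the first place.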
...@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import mpu
-from megatron.data import helpers
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.ict_dataset import InverseClozeDataset
...@@ -43,7 +42,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                           skip_warmup)
    if ict_dataset:
-       titles_dataset = get_indexed_dataset_(data_prefix + '-titles',
        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
                                              data_impl,
                                              skip_warmup)
...@@ -55,6 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
...@@ -83,7 +83,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        # Build the dataset accordingly.
        kwargs = dict(
            name=name,
-           context_dataset=indexed_dataset,
            data_prefix=data_prefix,
            num_epochs=None,
            max_num_samples=train_valid_test_num_samples[index],
...@@ -93,9 +92,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        )
        if ict_dataset:
-           dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
            dataset = InverseClozeDataset(
                block_dataset=indexed_dataset,
                title_dataset=title_dataset,
                **kwargs
            )
        else:
-           dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
            dataset = BertDataset(
                indexed_dataset=indexed_dataset,
                masked_lm_prob=masked_lm_prob,
                **kwargs
            )
        # Set the original pointer so dataset remains the main dataset.
        indexed_dataset.set_doc_idx(doc_idx_ptr)
        # Checks.
...@@ -261,6 +268,7 @@ def get_samples_mapping_(indexed_dataset,
        start_time = time.time()
        print_rank_0(' > building sapmles index mapping for {} ...'.format(
            name))
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
......
...@@ -272,7 +272,7 @@ def create_masked_lm_predictions(tokens,
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
-           ngram_index.append(cand_indexes[idx:idx+n])
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)
    np_rng.shuffle(ngram_indexes)
...@@ -406,12 +406,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
-   filler = [pad_id]*padding_length
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
-   padding_mask_np = np.array([1]*num_tokens + [0]*padding_length,
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)
    # Lables and loss mask.
......
...@@ -13,124 +13,305 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-"""GPT2 dataset."""
-
-import json
-import os
-
-import numpy as np
-import torch
-from torch.utils.data import Dataset
-
-
-class GPT2Dataset(Dataset):
-
-    def __init__(self, data_path, sizes_filename, seq_length,
-                 initial_seed, max_epochs=100):
-        # Input parameters.
-        self.data_path = data_path
-        self.sizes_filename = sizes_filename
-        self.seq_length = seq_length
-        self.initial_seed = initial_seed
-        self.max_epochs = max_epochs
-        # Shard stuff.
-        # Dictionary from shard nameto its size (number of element).
-        self.master_shard_size_dict = None
-        # Dictionary from shard name to modified size so it is
-        # divisible by self.seq_length.
-        self.shard_size_dict = None
-        # Long array (self.max_epochs * num-shards) populated
-        # randomly with shard names.
-        self.shards_name = None
-        # Start index of the data for a shard.
-        self.shards_start_index = None
-        self.build_shard_mappings_()
-        self.data_length = self.shards_start_index[-1]
-        # Data.
-        self.shards_data = [None]*self.shards_name.size
-        self.shards_sample_index = [None]*self.shards_name.size
-
-    def __len__(self):
-        return self.data_length
-
-    def __getitem__(self, idx):
-        # Find which shard we need.
-        shard_index = np.searchsorted(self.shards_start_index,
-                                      idx, side='right') - 1
-        # data index in the shard.
-        data_idx = idx - self.shards_start_index[shard_index]
-        # Load the shard if it is not in memory.
-        if self.shards_data[shard_index] is None:
-            print('global rank {} is building data for shard index {} ...'.
-                  format(torch.distributed.get_rank(), shard_index))
-            self.build_dataset_(shard_index)
-        #assert self.shards_data[shard_index] is not None
-        # Start index.
-        start_index = self.shards_sample_index[shard_index][data_idx]
-        # Add one for label shift.
-        end_index = start_index + self.seq_length + 1
-        data = self.shards_data[shard_index][start_index:end_index]
-        return {'text': np.array(data, dtype=np.int64)}
-
-    def build_dataset_(self, shard_index):
-        # Garbage collect so we don't use a lot of memory.
-        # Leave the last one in case other threads have not catche up yet.
-        #for i in range(shard_index - 1):
-        for i in range(shard_index):
-            self.shards_data[i] = None
-            self.shards_sample_index[i] = None
-        # Read the shard.
-        filename = os.path.join(self.data_path, self.shards_name[shard_index])
-        print('loading {}'.format(filename))
-        data = np.load(filename, allow_pickle=True)
-        # Shuffle the data
-        rng = np.random.RandomState(self.initial_seed + shard_index)
-        rng.shuffle(data)
-        # Flatten.
-        data = np.hstack(data)
-        size = (data.shape[0] - 1) // self.seq_length
-        last_index = size * self.seq_length + 1
-        data = data[0:last_index]
-        self.shards_data[shard_index] = data
-        indices = np.arange(size) * self.seq_length
-        rng.shuffle(indices)
-        self.shards_sample_index[shard_index] = indices
-
-    def build_shard_mappings_(self):
-        # Load the sizes file.
-        sizes_filename = os.path.join(self.data_path, self.sizes_filename)
-        if torch.distributed.get_rank() == 0:
-            print(' > loading sizes from {}'.format(sizes_filename))
-        with open(sizes_filename, 'r') as f:
-            self.master_shard_size_dict = json.load(f)
-        if torch.distributed.get_rank() == 0:
-            print(' found {} shards'.format(len(self.master_shard_size_dict)))
-        # Adjust sizes to be a multiple of seq_length.
-        self.shard_size_dict = self.master_shard_size_dict.copy()
-        total_samples = 0
-        for shard in self.shard_size_dict:
-            size = self.shard_size_dict[shard]
-            size = ((size - 1) // self.seq_length) * self.seq_length
-            total_samples += size // self.seq_length
-            self.shard_size_dict[shard] = size
-        if torch.distributed.get_rank() == 0:
-            print(' found {} samples in the dataset'.format(total_samples))
-        # Build a list of shards.
-        shards_ = np.sort(np.array(list(self.shard_size_dict.keys())))
-        rng = np.random.RandomState(self.initial_seed)
-        self.shards_name = np.copy(shards_)
-        rng.shuffle(self.shards_name)
-        for i in range(1, self.max_epochs):
-            shards_c = np.copy(shards_)
-            rng.shuffle(shards_c)
-            self.shards_name = np.append(self.shards_name, shards_c)
-        # Build the global indexing.
-        self.shards_start_index = np.zeros(self.shards_name.size, dtype=np.int)
-        self.shards_start_index[0] = 0
-        for i in range(1, self.shards_name.size):
-            shard = str(self.shards_name[i-1])
-            size = self.shard_size_dict[shard]
-            self.shards_start_index[i] = self.shards_start_index[i-1] + \
-                                         size // self.seq_length

"""GPT2 style dataset."""

import os
import time

import numpy as np
import torch

from megatron import print_rank_0
from megatron import mpu
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index+1],
                                  step=1, dtype=np.int32)
            dataset = GPT2Dataset(name, data_prefix,
                                  documents, indexed_dataset,
                                  train_valid_test_num_samples[index],
                                  seq_length, seed)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build indexed dataset."""
    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class GPT2Dataset(torch.utils.data.Dataset):

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

        self.name = name
        self.indexed_dataset = indexed_dataset

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
            self.name, data_prefix, documents, self.indexed_dataset.sizes,
            num_samples, seq_length, seed)

    def __len__(self):
        # -1 is due to data structure used to retieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1

    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx+1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx+1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                              offset=offset_f,
                                              length=offset_l - offset_f + 1)
        else:
            # Otherwise, get the rest of the initial document.
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                    offset=offset_f)]
            # Loop over all in between documents and add the entire document.
            for i in range(doc_index_f+1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of last document.
            sample_list.append(self.indexed_dataset.get(
                self.doc_idx[doc_index_l],
                length=offset_l+1))
            sample = np.concatenate(sample_list)

        return {'text': np.array(sample, dtype=np.int64)}


def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
    doc-idx: is an array (ordered) of documents to be used in training.
    sample-idx: is the start document index and document offset for each
       training sample.
    shuffle-idx: maps the sample index into a random index into sample-idx.
    """
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # rng state
    np_rng = np.random.RandomState(seed=seed)

    # Filename of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0:
        if (not os.path.isfile(doc_idx_filename)) or \
           (not os.path.isfile(sample_idx_filename)) or \
           (not os.path.isfile(shuffle_idx_filename)):

            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')
            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save doc-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # sample-idx.
            start_time = time.time()
            # Use C++ implementation for speed.
            from megatron.data import helpers
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32
            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                  num_epochs, tokens_per_epoch)
            #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
            #                               num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save sample-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            # -1 is due to data structure used to retieve the index:
            #    sample i --> [sample_idx[i], sample_idx[i+1])
            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0]-1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load mappings.
    start_time = time.time()
    print_rank_0(' > loading doc-idx mapping from {}'.format(
        doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True)
    print_rank_0(' > loading sample-idx mapping from {}'.format(
        sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True)
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True)
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        sample_idx.shape[0]))
    print_rank_0(' total number of epochs: {}'.format(num_epochs))

    return doc_idx, sample_idx, shuffle_idx


def _num_tokens(documents, sizes):
    """Total number of tokens in the dataset."""
    return np.sum(sizes[documents])


def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Based on number of samples and sequence lenght, calculate how many
    epochs will be needed."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        # -1 is because we need to retrieve seq_length + 1 token each time
        # but the last token will overlap with the first token of the next
        # sample except for the last sample.
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs


def _build_doc_idx(documents, num_epochs, np_rng):
    """Build an array with length = number-of-epochs * number-of-dcuments.
    Each index is mapped to a corresponding document."""
    doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
    doc_idx[:] = documents
    doc_idx = doc_idx.reshape(-1)
    doc_idx = doc_idx.astype(np.int32)
    np_rng.shuffle(doc_idx)
    return doc_idx


def _build_sample_idx(sizes, doc_idx, seq_length,
                      num_epochs, tokens_per_epoch):
    """Sample index mapping is a 2D array with sizes
    [number-of-samples + 1, 2] where [..., 0] contains
    the index into `doc_idx` and [..., 1] is the
    starting offset in that document."""

    # Total number of samples. For -1 see comments in `_num_epochs`.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    # Index into sample_idx.
    sample_index = 0
    # Index into doc_idx.
    doc_idx_index = 0
    # Begining offset for each document.
    doc_offset = 0
    # Start with first document and no offset.
    sample_idx[sample_index][0] = doc_idx_index
    sample_idx[sample_index][1] = doc_offset
    sample_index += 1
    while sample_index <= num_samples:
        # Start with a fresh sequence.
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            # Get the document length.
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            # And add it to the current sequence.
            remaining_seq_length -= doc_length
            # If we have more than a full sequence, adjust offset and set
            # remaining length to zero so we return from the while loop.
            # Note that -1 here is for the same reason we have -1 in
            # `_num_epochs` calculations.
            if remaining_seq_length <= 0:
                doc_offset += (remaining_seq_length + doc_length - 1)
                remaining_seq_length = 0
            else:
                # Otherwise, start from the begining of the next document.
                doc_idx_index += 1
                doc_offset = 0
        # Record the sequence.
        sample_idx[sample_index][0] = doc_idx_index
        sample_idx[sample_index][1] = doc_offset
        sample_index += 1

    return sample_idx


def _build_shuffle_idx(size, np_rng):
    """Build the range [0, size) and shuffle."""
    dtype_ = np.uint32
    if size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64
    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx)
    return shuffle_idx
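To make the sample-idx layout concrete, here is a tiny worked example (illustrative only, not part of the commit) that calls the pure-Python _build_sample_idx above with three toy documents and seq_length = 4:

import numpy as np

sizes = np.array([5, 3, 4], dtype=np.int32)    # token counts per document
doc_idx = np.array([0, 1, 2], dtype=np.int32)  # one epoch, unshuffled for clarity
sample_idx = _build_sample_idx(sizes, doc_idx, seq_length=4,
                               num_epochs=1, tokens_per_epoch=12)
# sample_idx == [[0, 0], [0, 4], [2, 0]]
#   sample 0 -> doc_idx[0], offsets 0..4 (seq_length + 1 = 5 tokens)
#   sample 1 -> tail of doc_idx[0], all of doc_idx[1], first token of doc_idx[2]
# The extra token per sample supplies the shifted next-token labels, which is
# why __getitem__ reads up to offset_l inclusive.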
...@@ -33,6 +33,95 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
                           const py::array_t<int32_t>& doc_idx_,
                           const int32_t seq_length,
                           const int32_t num_epochs,
                           const int64_t tokens_per_epoch) {
    /* Sample index (sample_idx) is used for gpt2 like dataset for which
       the documents are flattened and the samples are built based on this
       1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
       where [..., 0] contains the index into `doc_idx` and [..., 1] is the
       starting offset in that document.*/

    // Consistency checks.
    assert(seq_length > 1);
    assert(num_epochs > 0);
    assert(tokens_per_epoch > 1);

    // Remove bound checks.
    auto sizes = sizes_.unchecked<1>();
    auto doc_idx = doc_idx_.unchecked<1>();

    // Mapping and it's length (1D).
    int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
    int32_t* sample_idx = new int32_t[2*(num_samples+1)];

    cout << " using:" << endl << std::flush;
    cout << " number of documents: " <<
        doc_idx_.shape(0) / num_epochs << endl << std::flush;
    cout << " number of epochs: " << num_epochs <<
        endl << std::flush;
    cout << " sequence length: " << seq_length <<
        endl << std::flush;
    cout << " total number of samples: " << num_samples <<
        endl << std::flush;

    // Index into sample_idx.
    int64_t sample_index = 0;
    // Index into doc_idx.
    int64_t doc_idx_index = 0;
    // Begining offset for each document.
    int32_t doc_offset = 0;
    // Start with first document and no offset.
    sample_idx[2 * sample_index] = doc_idx_index;
    sample_idx[2 * sample_index + 1] = doc_offset;
    ++sample_index;

    while (sample_index <= num_samples) {
        // Start with a fresh sequence.
        int32_t remaining_seq_length = seq_length + 1;
        while (remaining_seq_length != 0) {
            // Get the document length.
            auto doc_id = doc_idx[doc_idx_index];
            auto doc_length = sizes[doc_id] - doc_offset;
            // And add it to the current sequence.
            remaining_seq_length -= doc_length;
            // If we have more than a full sequence, adjust offset and set
            // remaining length to zero so we return from the while loop.
            // Note that -1 here is for the same reason we have -1 in
            // `_num_epochs` calculations.
            if (remaining_seq_length <= 0) {
                doc_offset += (remaining_seq_length + doc_length - 1);
                remaining_seq_length = 0;
            } else {
                // Otherwise, start from the begining of the next document.
                ++doc_idx_index;
                doc_offset = 0;
            }
        }
        // Record the sequence.
        sample_idx[2 * sample_index] = doc_idx_index;
        sample_idx[2 * sample_index + 1] = doc_offset;
        ++sample_index;
    }

    // Method to deallocate memory.
    py::capsule free_when_done(sample_idx, [](void *mem_) {
        int32_t *mem = reinterpret_cast<int32_t*>(mem_);
        delete[] mem;
    });

    // Return the numpy array.
    const auto byte_size = sizeof(int32_t);
    return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
                     {2*byte_size, byte_size}, // C-style contiguous strides
                     sample_idx, // the data pointer
                     free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                     const int32_t max_length,
                                     std::mt19937& rand32_gen) {
...@@ -516,4 +605,5 @@ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
    m.def("build_blocks_mapping", &build_blocks_mapping);
    m.def("build_sample_idx", &build_sample_idx);
}
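Since the C++ kernel mirrors the pure-Python _build_sample_idx line for line, a quick consistency check can be run once the extension is compiled. This is an illustrative sketch only; it assumes the new dataset module lives at megatron.data.gpt2_dataset and that megatron.data.helpers has been built:

import numpy as np

from megatron.data import helpers
from megatron.data.gpt2_dataset import _build_doc_idx, _build_sample_idx

np_rng = np.random.RandomState(1234)
sizes = np_rng.randint(2, 50, size=1000).astype(np.int32)   # synthetic doc lengths
documents = np.arange(1000, dtype=np.int32)
num_epochs, seq_length = 3, 512
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
tokens_per_epoch = int(np.sum(sizes[documents]))

cpp_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                   num_epochs, tokens_per_epoch)
py_idx = _build_sample_idx(sizes, doc_idx, seq_length,
                           num_epochs, tokens_per_epoch)
assert np.array_equal(cpp_idx, py_idx)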
...@@ -15,14 +15,14 @@ from megatron.data import helpers
class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
-   def __init__(self, name, context_dataset, titles_dataset, data_prefix,
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
-       self.context_dataset = context_dataset
-       self.titles_dataset = titles_dataset
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
...@@ -41,38 +41,37 @@ class InverseClozeDataset(Dataset):
    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
-       title = list(self.titles_dataset[int(doc_idx)])
-       context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
-       assert len(context) > 1
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence to be the query.
-       if len(context) == 2:
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
-           rand_sent_idx = self.rng.randint(1, len(context) - 2)
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

-       # keep the query in the context 10% of the time.
        # keep the query in the block 10% of the time.
        if self.rng.random() < 0.1:
-           input = context[rand_sent_idx].copy()
            query = block[rand_sent_idx].copy()
        else:
-           input = context.pop(rand_sent_idx)
            query = block.pop(rand_sent_idx)

-       # may still need to truncate because blocks are concluded when
        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
-       input = input[:self.max_seq_length - 2]
-       context = list(itertools.chain(*context))[:self.max_seq_length - (3 + len(title))]
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

-       input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
-       context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
        query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
-           'input_text': np.array(input_tokens),
-           'input_types': np.array(input_token_types),
-           'input_pad_mask': np.array(input_pad_mask),
-           'context_text': np.array(context_tokens),
-           'context_types': np.array(context_token_types),
-           'context_pad_mask': np.array(context_pad_mask),
-           'context_indices': np.array([block_idx])
            'query_tokens': np.array(query_tokens),
            'query_types': np.array(query_token_types),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_types': np.array(block_token_types),
            'block_pad_mask': np.array(block_pad_mask),
            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx])
        }
        return sample
......
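The renamed fields make the inverse cloze construction explicit: one sentence of a block becomes the query, and it is removed from the block 90% of the time so the model cannot rely on exact overlap. A standalone sketch of that selection rule (illustrative only; the function and parameter names below are not from the source):

import random


def split_query_and_block(sentences, rng, keep_prob=0.1):
    # sentences: list of token-id lists belonging to one block.
    block = [list(s) for s in sentences]
    assert len(block) > 1
    if len(block) == 2:
        rand_sent_idx = int(rng.random() > 0.5)
    else:
        # avoid selecting the first or last sentence as the query,
        # mirroring __getitem__ above.
        rand_sent_idx = rng.randint(1, len(block) - 2)
    if rng.random() < keep_prob:
        query = block[rand_sent_idx].copy()   # query kept inside the block
    else:
        query = block.pop(rand_sent_idx)      # query removed from the block
    return query, block


query, block = split_query_and_block([[1, 2], [3, 4, 5], [6, 7]], random.Random(0))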