Commit e4575be9 authored by huaerkl's avatar huaerkl
Browse files

v1.0

parents
import os
import time
import numpy as np
import torch
from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
# Use megatron's sampler with consumed samples set to 0 as
# this is only for evaluation and don't intend to resume half way.
# Also, set the drop last to false as don't intend to remove
# the last batch
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=0,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
drop_last=False)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_mask',
'context_tokens', 'context_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_mask = data_b['query_mask'] < 0.5
context_tokens = data_b['context_tokens'].long()
context_mask = data_b['context_mask'] < 0.5
block_indices = data_b['block_data'].long()
return query_tokens, query_mask,\
context_tokens, context_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Blendable dataset."""
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
class BlendableDataset(torch.utils.data.Dataset):
def __init__(self, datasets, weights):
self.datasets = datasets
num_datasets = len(datasets)
assert num_datasets == len(weights)
self.size = 0
for dataset in self.datasets:
self.size += len(dataset)
# Normalize weights.
weights = np.array(weights, dtype=np.float64)
sum_weights = np.sum(weights)
assert sum_weights > 0.0
weights /= sum_weights
# Build indecies.
start_time = time.time()
assert num_datasets < 255
self.dataset_index = np.zeros(self.size, dtype=np.uint8)
self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)
from megatron.data import helpers
helpers.build_blending_indices(self.dataset_index,
self.dataset_sample_index,
weights, num_datasets, self.size,
torch.distributed.get_rank() == 0)
print_rank_0('> elapsed time for building blendable dataset indices: '
'{:.2f} (sec)'.format(time.time() - start_time))
def __len__(self):
return self.size
def __getitem__(self, idx):
dataset_idx = self.dataset_index[idx]
sample_idx = self.dataset_sample_index[idx]
return self.datasets[dataset_idx][sample_idx]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataloaders."""
import torch
from megatron import get_args
from megatron import mpu
from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset
def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
"""Buld dataloader given an input dataset."""
if dataset is None:
return None
args = get_args()
# Megatron sampler
if args.dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
if num_workers is None:
num_workers = args.num_workers
# Torch dataloader.
return torch.utils.data.DataLoader(
dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
generator=torch.Generator().manual_seed(args.seed),
collate_fn=None,
pin_memory=True
)
class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, drop_last=True):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.consumed_samples < self.total_samples, \
'no samples left to consume: {}, {}'.format(self.consumed_samples,
self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx
def __iter__(self):
batch = []
# Last batch will be dropped if drop_last is not set False
for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []
# Check the last partial batch and see drop_last is set
if len(batch) > 0 and not self.drop_last:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, and NVIDIA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import math
import os
import time
import collections
import numpy as np
import torch
from megatron import (
get_args,
mpu,
print_rank_0
)
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_BERT = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPE_T5 = 't5'
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
def analyze_data_prefix(data_prefix):
# The data prefix should be in the format of:
# weight-1, data-prefix-1, weight-2, data-prefix-2, ..
assert len(data_prefix) % 2 == 0
num_datasets = len(data_prefix) // 2
weights = [0]*num_datasets
prefixes = [0]*num_datasets
for i in range(num_datasets):
weights[i] = float(data_prefix[2*i])
prefixes[i] = (data_prefix[2*i+1]).strip()
# Normalize weights
weight_sum = 0.0
for weight in weights:
weight_sum += weight
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
return prefixes, weights
def get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples):
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
prefixes, weights = analyze_data_prefix(data_prefix)
datasets_train_valid_test_num_samples = []
for weight in weights:
datasets_train_valid_test_num_samples.append(
[int(math.ceil(val * weight * 1.005))
for val in train_valid_test_num_samples])
return prefixes, weights, datasets_train_valid_test_num_samples
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
import os
import subprocess
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(['make', '-C', path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
# Number of sentences in the sample.
n_sentences = len(sample)
# Make sure we always have two sentences.
assert n_sentences > 1, 'make sure each sample has at least two sentences.'
# First part:
# `a_end` is how many sentences go into the `A`.
a_end = 1
if n_sentences >= 3:
# Note that randin in numpy is exclusive.
a_end = np_rng.randint(1, n_sentences)
tokens_a = []
for j in range(a_end):
tokens_a.extend(sample[j])
# Second part:
tokens_b = []
for j in range(a_end, n_sentences):
tokens_b.extend(sample[j])
# Random next:
is_next_random = False
if np_rng.random() < 0.5:
is_next_random = True
tokens_a, tokens_b = tokens_b, tokens_a
return tokens_a, tokens_b, is_next_random
def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens)
assert len_a > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
if len_a > len_b:
len_a -= 1
tokens = tokens_a
else:
len_b -= 1
tokens = tokens_b
if np_rng.random() < 0.5:
del tokens[0]
else:
tokens.pop()
return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
tokens = []
tokentypes = []
# [CLS].
tokens.append(cls_id)
tokentypes.append(0)
# Segment A.
for token in tokens_a:
tokens.append(token)
tokentypes.append(0)
# [SEP].
tokens.append(sep_id)
tokentypes.append(0)
# Segment B.
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
return not piece.startswith("##")
def create_masked_lm_predictions(tokens,
vocab_id_list, vocab_id_to_token_dict,
masked_lm_prob,
cls_id, sep_id, mask_id,
max_predictions_per_seq,
np_rng,
max_ngrams=3,
do_whole_word_mask=True,
favor_longer_ngram=False,
do_permutation=False,
geometric_dist=False,
masking_style="bert"):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = []
# Note(mingdachen): We create a list for recording if the piece is
# the starting piece of current token, where 1 means true, so that
# on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id:
token_boundary[i] = 1
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and
not is_start_piece(vocab_id_to_token_dict[token])):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
if is_start_piece(vocab_id_to_token_dict[token]):
token_boundary[i] = 1
output_tokens = list(tokens)
masked_lm_positions = []
masked_lm_labels = []
if masked_lm_prob == 0:
return (output_tokens, masked_lm_positions,
masked_lm_labels, token_boundary)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
if not geometric_dist:
# Note(mingdachen):
# By default, we set the probilities to favor shorter ngram sequences.
pvals = 1. / np.arange(1, max_ngrams + 1)
pvals /= pvals.sum(keepdims=True)
if favor_longer_ngram:
pvals = pvals[::-1]
ngram_indexes = []
for idx in range(len(cand_indexes)):
ngram_index = []
for n in ngrams:
ngram_index.append(cand_indexes[idx:idx + n])
ngram_indexes.append(ngram_index)
np_rng.shuffle(ngram_indexes)
(masked_lms, masked_spans) = ([], [])
covered_indexes = set()
for cand_index_set in ngram_indexes:
if len(masked_lms) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes:
continue
if not geometric_dist:
n = np_rng.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
else:
# Sampling "n" from the geometric distribution and clipping it to
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
n = min(np_rng.geometric(0.2), max_ngrams)
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# Note(mingdachen):
# Repeatedly looking for a candidate that does not exceed the
# maximum number of predictions by trying shorter ngrams.
while len(masked_lms) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
if masking_style == "bert":
# 80% of the time, replace with [MASK]
if np_rng.random() < 0.8:
masked_token = mask_id
else:
# 10% of the time, keep original
if np_rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
elif masking_style == "t5":
masked_token = mask_id
else:
raise ValueError("invalid value of masking style")
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_spans.append(MaskedLmInstance(
index=index_set,
label=[tokens[index] for index in index_set]))
assert len(masked_lms) <= num_to_predict
np_rng.shuffle(ngram_indexes)
select_indexes = set()
if do_permutation:
for cand_index_set in ngram_indexes:
if len(select_indexes) >= num_to_predict:
break
if not cand_index_set:
continue
# Note(mingdachen):
# Skip current piece if they are covered in lm masking or previous ngrams.
for index_set in cand_index_set[0]:
for index in index_set:
if index in covered_indexes or index in select_indexes:
continue
n = np.random.choice(ngrams[:len(cand_index_set)],
p=pvals[:len(cand_index_set)] /
pvals[:len(cand_index_set)].sum(keepdims=True))
index_set = sum(cand_index_set[n - 1], [])
n -= 1
while len(select_indexes) + len(index_set) > num_to_predict:
if n == 0:
break
index_set = sum(cand_index_set[n - 1], [])
n -= 1
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(select_indexes) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes or index in select_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
select_indexes.add(index)
assert len(select_indexes) <= num_to_predict
select_indexes = sorted(select_indexes)
permute_indexes = list(select_indexes)
np_rng.shuffle(permute_indexes)
orig_token = list(output_tokens)
for src_i, tgt_i in zip(select_indexes, permute_indexes):
output_tokens[src_i] = orig_token[tgt_i]
masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
# Sort the spans by the index of the first span
masked_spans = sorted(masked_spans, key=lambda x: x.index[0])
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary, masked_spans)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length):
"""Pad sequences and convert them to numpy."""
# Some checks.
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64)
# Lables and loss mask.
labels = [-1] * max_seq_length
loss_mask = [0] * max_seq_length
for i in range(len(masked_positions)):
assert masked_positions[i] < num_tokens
labels[masked_positions[i]] = masked_labels[i]
loss_mask[masked_positions[i]] = 1
labels_np = np.array(labels, dtype=np.int64)
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob, short_seq_prob, seed,
skip_warmup, binary_head=False,
max_seq_length_dec=None,
dataset_type='standard_bert'):
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
max_seq_length_dec,
dataset_type=dataset_type)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length,
masked_lm_prob, short_seq_prob, seed,
skip_warmup, binary_head,
max_seq_length_dec,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/train into doc-idx
# Note that doc-idx is desinged to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset
from megatron.data.ict_dataset import ICTDataset
from megatron.data.t5_dataset import T5Dataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed,
)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.use_one_sent_docs,
binary_head=binary_head,
**kwargs
)
elif dataset_type == DSET_TYPE_T5:
dataset = T5Dataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
max_seq_length_dec=max_seq_length_dec,
short_seq_prob=short_seq_prob,
**kwargs
)
elif dataset_type == DSET_TYPE_BERT:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
binary_head=binary_head,
**kwargs
)
else:
raise NotImplementedError("Dataset type not fully implemented.")
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_split_by_range_(range_string, size):
""" Get dataset splits based on a range:
range_string is in the form START%:END% for e.g. 0.2:0.8
outputs an array of two values [start_index, end_index]
"""
# some checks that range is given in the correct form
splits = [float(i) for i in range_string.split(":")]
assert len(splits) == 2, "splits should be passed as start:end"
assert splits[0] <= 1 and splits[1] <= 1
splits_sum = sum(splits)
assert splits_sum > 0.0
splits_index = [round(s * float(size)) for s in splits]
assert len(splits_index) == 2
return splits_index
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
def get_samples_mapping(indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
name,
binary_head):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert indexed_dataset.doc_idx.dtype == np.int64
assert indexed_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building sapmles index mapping for {} ...'.format(
name))
# First compile and then import.
from megatron.data import helpers
samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx,
indexed_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length,
short_seq_prob,
seed,
verbose,
2 if binary_head else 1)
print_rank_0(' > done building sapmles index maping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elasped time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
import os
import time
import numpy as np
import torch
from megatron import print_rank_0, mpu, logging
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \
get_train_valid_test_split_
from megatron.data.mtf_dataset import MTFDataset
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
logger = logging.get_logger(__name__)
def build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
seq_length: int,
pad_token: int,
eos_token: int,
train_valid_test_num_samples,
seed,
skip_warmup
):
"""Build train, valid, and test datasets."""
# Single dataset.
if len(data_prefix) == 1:
all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(
data_prefix=data_prefix[0],
data_impl=data_impl,
splits_string=splits_string,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
train_valid_test_num_samples=train_valid_test_num_samples,
seed=seed,
skip_warmup=skip_warmup
)
# Blending dataset.
else:
output = get_datasets_weights_and_num_samples(data_prefix=data_prefix, train_valid_test_num_samples=train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
data_prefix=prefixes[i],
data_impl=data_impl,
splits_string=splits_string,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
seed=seed,
skip_warmup=skip_warmup
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
all_train_datasets = BlendableDataset(train_datasets, weights) \
if train_datasets else None
all_valid_datasets = BlendableDataset(valid_datasets, weights) \
if valid_datasets else None
all_test_datasets = BlendableDataset(test_datasets, weights) \
if test_datasets else None
return all_train_datasets, all_valid_datasets, all_test_datasets
def build_dataset_group(
dataset_group_name,
paths,
weights,
splits,
data_impl,
seq_length: int,
pad_token: int,
eos_token: int,
train_valid_test_num_samples,
seed,
skip_warmup,
train_valid_test
):
'''
Build a single dataset group corresponding to Option 2 of data loading see arguments.py
a dataset group is passed in the following form
GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2
or alternatively
GIVEN_NAME PATH1 # for a single dataset to be used fully
'''
assert train_valid_test in ["train","valid","test"]
# Single dataset.
if len(paths) == 1:
dataset = _build_single_datasets(
data_prefix=paths[0],
range_string=splits[0],
data_impl=data_impl,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
train_valid_test_num_samples=train_valid_test_num_samples,
seed=seed,
skip_warmup=skip_warmup,
dataset_group_name=dataset_group_name,
train_valid_test=train_valid_test
)
return dataset
# Blending dataset.
else:
data_prefix = []
# data_prefix is of the shape:
# ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"]
for w,p in zip(weights, paths):
data_prefix += [w,p]
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
datasets = []
for i in range(len(prefixes)):
ds = _build_single_datasets(
data_prefix=prefixes[i],
range_string=splits[i],
data_impl=data_impl,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
seed=seed,
skip_warmup=skip_warmup,
dataset_group_name=dataset_group_name,
train_valid_test=train_valid_test
)
datasets.append(ds)
all_datasets = BlendableDataset(datasets, weights)
return all_datasets
def _build_single_datasets(
data_prefix,
range_string,
data_impl,
seq_length: int,
pad_token: int,
eos_token: int,
train_valid_test_num_samples,
seed,
skip_warmup,
dataset_group_name,
train_valid_test
):
"""Build a single dataset"""
assert train_valid_test in ["train","valid","test"]
index = ["train","valid","test"].index(train_valid_test)
# Target indexed dataset.
target_indexed_dataset = get_indexed_dataset(
data_prefix=data_prefix,
is_input=False,
data_impl=data_impl,
skip_warmup=skip_warmup
)
total_num_of_documents = target_indexed_dataset.sizes.shape[0]
# this corresponds to option2 for data loading on the form
# WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
# splits here is an array of size 2 [start_index, end_index]
splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
print_rank_0(' {}:'.format(dataset_group_name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[0], splits[1],
splits[1] - splits[0]))
def build_dataset(name):
dataset = None
if splits[1] > splits[0]:
documents = np.arange(start=splits[0], stop=splits[1],
step=1, dtype=np.int32)
dataset = DecoderPackedMTFDataset(
name=name,
data_prefix=data_prefix,
data_impl=data_impl,
skip_warmup=skip_warmup,
documents=documents,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
num_samples=train_valid_test_num_samples[index],
seed=seed
)
return dataset
dataset = build_dataset(dataset_group_name)
return dataset
def _build_train_valid_test_datasets(
data_prefix,
data_impl,
splits_string,
seq_length: int,
pad_token: int,
eos_token: int,
train_valid_test_num_samples,
seed,
skip_warmup
):
"""Build train, valid, and test datasets."""
# Target indexed dataset.
target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup)
total_num_of_documents = target_indexed_dataset.sizes.shape[0]
# splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = DecoderPackedMTFDataset(
name=name,
data_prefix=data_prefix,
data_impl=data_impl,
skip_warmup=skip_warmup,
documents=documents,
seq_length=seq_length,
pad_token=pad_token,
eos_token=eos_token,
num_samples=train_valid_test_num_samples[index],
seed=seed
)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
class DecoderPackedMTFDataset(torch.utils.data.Dataset):
def __init__(
self,
name,
data_prefix,
data_impl,
skip_warmup,
documents,
num_samples,
seq_length: int,
pad_token: int,
eos_token: int,
seed,
):
self.mtf_dataset = MTFDataset(name=name, data_prefix=data_prefix, data_impl=data_impl, skip_warmup=skip_warmup, documents=documents)
self.pad_token = pad_token
self.seq_length = seq_length
self.sample_index, self.shuffle_index = _build_index_mappings(name=name, data_prefix=data_prefix, nb_documents=len(documents), mtf_dataset=self.mtf_dataset, num_samples=num_samples, seq_length=seq_length, seed=seed)
def __len__(self):
return len(self.sample_index)
def __getitem__(self, idx):
# Get the shuffled index.
start, end = self.sample_index[idx]
mtf_samples_indices = self.shuffle_index[start: end]
# TODO @thomasw21 build a dataset that generates an entire batch instead of a row (allows for more optimization)
items = [self.mtf_dataset[sample_id] for sample_id in mtf_samples_indices]
return self.pack_samples(items)
def pack_samples(self, items):
"""
Greedily packs samples.
Items:
[
{
'input_tokens': array([6, 7]),
'target_tokens': array([8])
},
{
'input_tokens': array([3, 4]),
'target_tokens': array([5])
}
]
Output:
decoder_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of tokens followed with padding tokens.
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids determine original documents.
decoder_is_inputs = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts inputs, `0` depicts target.
"""
decoder_tokens = np.full((self.seq_length,), self.pad_token, dtype=np.int64)
decoder_segment_ids = np.zeros((self.seq_length,), dtype=np.int64)
decoder_is_inputs = np.full((self.seq_length,), False, dtype=bool)
# `0` is reserved for padding
item_num = 1
cur_len = 0
assert len(items) > 0
for token_dict in items:
input_token_len = len(token_dict["input_tokens"])
target_token_len = len(token_dict["target_tokens"])
total_len = input_token_len + target_token_len
if cur_len + total_len > self.seq_length:
# This should not happen at the indexing should only allow the correct number of items
raise ValueError(f"""Items to be packed do not fit inside a single sample.
current length: {cur_len}
input tokens length: {input_token_len}
target token length: {target_token_len}
expected sequence length: {self.seq_length}
""")
decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
decoder_segment_ids[cur_len: cur_len + total_len] = item_num
decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs
# targets are already 0 at init, no need to update `decoder_is_inputs`
item_num += 1
cur_len += total_len
assert cur_len <= self.seq_length
return {
"decoder_token_ids": decoder_tokens,
"decoder_segment_ids": decoder_segment_ids,
"decoder_is_inputs": decoder_is_inputs,
}
def _build_index_mappings(
name,
data_prefix,
nb_documents,
mtf_dataset,
num_samples: int,
seq_length: int,
seed,
):
"""
- `shuffle_index` is [num_epoch * len(self.mtf)]
- `sample_index` is [num_sample, 2] (storing the start and end of the sample). We query the sample via `self.shuffle_index[start:end]`
TODO @thomas21 Instead of loading individually samples, we save the packing one and for all
"""
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}s'.format(seed)
sample_idx_filename = _filename + '_decoder_packed_batch_idx.npy'
shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
# iteratively add the entire dataset for every epoch and see if it's enough given current packing strategy
start_time = time.time()
row_offset = 0
old_sample_start = 0
epoch = 0
shuffle_idx = []
sample_idx = []
while len(sample_idx) <= num_samples:
new_document_ids = _build_shuffle_idx(nb_documents=nb_documents, np_rng=np_rng)
# Generate a shuffling of the entire dataset
shuffle_idx.append(new_document_ids)
# Packs them into a single sample
new_samples, row_offset, old_sample_start = _build_sample_idx(
mtf_dataset=mtf_dataset,
document_ids=new_document_ids,
seq_length=seq_length,
row_offset=row_offset,
old_sample_start=old_sample_start,
epoch=epoch
)
sample_idx.extend(new_samples)
epoch += 1
shuffle_idx = np.concatenate(shuffle_idx, axis=0)
sample_idx = np.stack(sample_idx, axis=0)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
return sample_idx, shuffle_idx
def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sample_start, epoch):
"""Build start and off index of each `full` batch, return that list of batch + start of the unfinished batch"""
row_length = row_offset
full_samples = []
current_sample_start = old_sample_start
epoch_offset = epoch * len(document_ids)
assert epoch_offset >= current_sample_start
for current_sample_end, document_id in enumerate(document_ids):
current_sample_end = epoch_offset + current_sample_end
sample_sizes = mtf_dataset.size(document_id)
# TODO @thomasw21 figure out if we add <eos> tokens
tok_len = sample_sizes["input_tokens"] + sample_sizes["target_tokens"]
row_length = row_length + tok_len
if row_length > seq_length:
# current sample can't be added and requires to be added in the next one
if current_sample_end > current_sample_start:
full_samples.append(np.asarray([current_sample_start, current_sample_end]))
current_sample_start = current_sample_end
row_length = tok_len
if tok_len > seq_length:
# TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can
# - silently skip that value [currently implemented]
# - truncate to `seq_length`, and keep the right part
logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}")
current_sample_start = current_sample_end + 1 # skipping
row_length = 0
continue
return full_samples, row_length, current_sample_start
def _build_shuffle_idx(nb_documents: int, np_rng):
"""Build the range [0, dataset_size) and shuffle."""
dtype_ = np.int64
result = np.arange(start=0, stop=nb_documents, step=1, dtype=dtype_)
# in-place shuffling
np_rng.shuffle(result)
return result
def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool):
if is_input:
field = "inputs"
else:
field = "targets"
return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup)
def get_indexed_dataset_(path, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(path,
data_impl,
skip_warmup)
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' number of documents: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
class DistDataError(Exception):
"""Defines an empty exception to throw when some other rank hit a real exception."""
pass
class DistData(object):
def __init__(self, backend='gloo'):
assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'"
dist.init_process_group(backend, init_method="env://")
# lookup our process rank and the group size
self.rank = dist.get_rank()
self.numranks = dist.get_world_size()
def allassert(self, cond, msg):
"""Check that cond is True on all ranks, assert with msg everywhere if not.
To prevent deadlocks in cases where an assertion might only fail on one rank,
this executes an allreduce to ensure that if any rank finds that an assertion
has been violated, all ranks fail an assertion check.
The condition must be true on all ranks for this not to assert.
"""
alltrue = self.alltrue(cond)
assert alltrue, msg
def allraise_if(self, err):
"""Raise exception if err is not None on any rank.
Similarly to allassert, this raises an exception on all ranks if err
is set to an exception on any rank. Rank(s) where err is not None
re-raise err as exception, and ranks where err is None raise DistDataError.
Thus all ranks raise an exception if any rank has an active exception,
which helps avoid deadlocks in cases where an exception may be raised
on a subset of ranks.
"""
alltrue = self.alltrue(err is None)
if not alltrue:
# At least one rank raised an exception.
# Re-raise the actual exception if this rank threw one.
if err is not None:
raise err
# TODO: is there a better exception to use here?
# On other ranks, raise an "empty" exception to indicate
# that we're only failing because someone else did.
raise DistDataError
def barrier(self):
"""Globally synchronize all processes"""
dist.barrier()
def bcast(self, val, root):
"""Broadcast a scalar value from root to all ranks"""
vals = [val]
dist.broadcast_object_list(vals, src=root)
return vals[0]
def scatterv_(self, invals: np.array, counts: list, root:int=0):
"""Scatter int64 values from invals according to counts array, return received portion in a new tensor"""
self.allassert(len(counts) == self.numranks,
f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}")
# Define list of tensors to scatter on the root.
# torch.distributed.scatter requires each tensor to be the same shape,
# so find the max size across all count values and pad.
max_size = max(counts)
scatterlist = None
if self.rank == root:
slices = list(torch.split(torch.from_numpy(invals), counts))
scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]
# Receive a tensor of the max count size from the root,
# then copy values into output numpy array, which may be smaller.
recvtensor = torch.zeros(max_size, dtype=torch.int64)
dist.scatter(recvtensor, scatterlist, src=root)
return recvtensor[:counts[self.rank]]
def alltrue(self, val):
"""Returns True if all procs input True, False otherwise"""
# torch.dist does not support reductions with bool types
# so we cast to int and cast the result back to bool
tensor = torch.tensor([int(val)], dtype=torch.int32)
dist.all_reduce(tensor, op=dist.ReduceOp.BAND)
return bool(tensor[0])
def sum(self, val):
"""Compute sum of a scalar val, and return total on all ranks."""
tensor = torch.tensor([val])
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
return tensor[0]
def exscan(self, val: int):
"""Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank."""
# torch.distributed doesn't have a scan, so fallback to allreduce
tensor = torch.zeros(self.numranks, dtype=torch.int64)
tensor[self.rank:] = val
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
return int(tensor[self.rank]) - val
def min(self, val):
"""Return minimum of scalar val to all ranks."""
tensor = torch.tensor([val])
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
return tensor[0]
def minrank(self, cond):
"""Find first rank whose condition is True, return that rank if any, None otherwise."""
minrank = self.numranks
if cond:
minrank = self.rank
minrank = self.min(minrank)
if minrank < self.numranks:
return minrank
return None
def bcast_first(self, val):
"""Broadcast val from first rank where it is not None, return val if any, None otherwise"""
# Find the first rank with a valid value.
minrank = self.minrank(val is not None)
# If there is no rank with a valid value, return None
if minrank is None:
return None
# Otherwise broadcast the value from the first valid rank.
val = self.bcast(val, root=minrank)
return val
def all_sum_(self, vals: np.array):
"""Sums values in numpy array vals element-wise and update vals in place with final result on all ranks"""
# Builds torch.tensor with from_numpy to use same underlying memory as numpy array.
tensor = torch.from_numpy(vals)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
def open(self, filename, truncate=None):
"""Create, truncate, and open a file shared by all ranks."""
# Don't truncate existing file until all ranks reach this point
self.barrier()
# We'll capture any exception in this variable
err = None
# Rank 0 creates and truncates file.
if self.rank == 0:
try:
f = open(filename, 'wb')
# Some file systems like GPFS deliver faster write speed
# if the file size is known before data is written to the file.
if truncate is not None:
f.truncate(truncate)
except Exception as e:
err = e
# Verify that rank 0 created the file
self.allraise_if(err)
# Wait for rank 0 to open (and truncate) file,
# then have all ranks open file for writing.
if self.rank != 0:
try:
f = open(filename, 'r+b')
except Exception as e:
err = e
# Verify that all ranks successfully opened the file
self.allraise_if(err)
return f
def remove(self, filename):
"""Remove a shared file."""
# Don't remove the file until all are ready
self.barrier()
# We'll capture any exception in this variable
err = None
# Rank 0 removes the file if it exists.
if self.rank == 0:
try:
if os.path.exists(filename):
os.remove(filename)
except Exception as e:
err = e
# Verify that rank 0 successfully removed the file.
self.allraise_if(err)
def rename(self, srcfile, destfile):
"""Rename a shared file."""
# Don't rename until all are ready
self.barrier()
# We'll capture any exception in this variable
err = None
# Rank 0 renames the file.
if self.rank == 0:
try:
if os.path.exists(srcfile):
os.rename(srcfile, destfile)
except Exception as e:
err = e
# Verify that the rename succeeded
self.allraise_if(err)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT style dataset."""
import os
import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Single dataset.
if len(data_prefix) == 1:
all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(data_prefix[0],
data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup)
# Blending dataset.
else:
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
all_train_datasets = BlendableDataset(train_datasets, weights) \
if train_datasets else None
all_valid_datasets = BlendableDataset(valid_datasets, weights) \
if valid_datasets else None
all_test_datasets = BlendableDataset(test_datasets, weights) \
if test_datasets else None
return all_train_datasets, all_valid_datasets, all_test_datasets
def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl,
train_valid_test_num_samples,
seq_length, seed, skip_warmup, train_valid_test):
'''
Build a single dataset group corresponding to Option 2 of data loading see arguments.py
a dataset group is passed on the following form
GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2
or alternatively
GIVEN_NAME PATH1 # for a single dataset to be used fully
'''
assert train_valid_test in ["train","valid","test"]
# Single dataset.
if len(paths) == 1:
dataset = _build_single_datasets(paths[0],
splits[0],
data_impl,
train_valid_test_num_samples,
seq_length, seed, skip_warmup,
dataset_group_name, train_valid_test)
return dataset
# Blending dataset.
else:
data_prefix = []
# data_prefix is on the shape:
# ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"]
for w,p in zip(weights, paths):
data_prefix += [w,p]
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
datasets = []
for i in range(len(prefixes)):
ds = _build_single_datasets(prefixes[i],
splits[i],
data_impl,
datasets_train_valid_test_num_samples[i],
seq_length,
seed, skip_warmup,
dataset_group_name, train_valid_test)
datasets.append(ds)
all_datasets = BlendableDataset(datasets, weights)
return all_datasets
def _build_single_datasets(data_prefix, range_string, data_impl, train_valid_test_num_samples,
seq_length, seed, skip_warmup, dataset_group_name, train_valid_test):
"""Build a single dataset"""
assert train_valid_test in ["train","valid","test"]
index = ["train","valid","test"].index(train_valid_test)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
# this corresponds to option2 for data loading on the form
# WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
# splits here is an array of size 2 [start_index, end_index]
splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
print_rank_0(' {}:'.format(dataset_group_name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[0], splits[1],
splits[1] - splits[0]))
def build_dataset(name):
dataset = None
if splits[1] > splits[0]:
documents = np.arange(start=splits[0], stop=splits[1],
step=1, dtype=np.int32)
dataset = GPTDataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed)
return dataset
dataset = build_dataset(dataset_group_name)
return dataset
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
# splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPTDataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(path, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(path,
data_impl,
skip_warmup)
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' number of documents: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
class GPTDataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed):
self.name = name
self.indexed_dataset = indexed_dataset
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
def __len__(self):
# -1 is due to data structure used to retieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx + 1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1)
else:
# Otherwise, get the rest of the initial document.
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(self.indexed_dataset.get(
self.doc_idx[doc_index_l],
length=offset_l + 1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples, seq_length, seed, cutoff_last_epoch=0.95):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
training sample.
shuffle-idx: maps the sample index into a random index into sample-idx.
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
_filename += '_{}s'.format(seed)
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
# For the last epoch, decide whether include the entire epoch
# in the global shuffle or not.
# If we need only one epoch, then separating last epoch does
# not mean anything.
if num_epochs == 1:
separate_last_epoch = False
print(' > only one epoch required, setting '
'separate_last_epoch to False', flush=True)
else:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one = (
(num_epochs - 1) * tokens_per_epoch - 1) // seq_length
last_epoch_num_samples = num_samples - \
num_samples_from_epochs_minus_one
assert last_epoch_num_samples >= 0, \
f'last epoch number of samples {last_epoch_num_samples} should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples <= num_samples_per_epoch, \
f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.'
# If we have less than cutoff_last_epoch * samples_per_epoch of the samples for the last epoch,
# seperate out the epoch and treat it differently.
separate_last_epoch = (last_epoch_num_samples <
int(cutoff_last_epoch * num_samples_per_epoch))
if separate_last_epoch:
string = ' > last epoch number of samples ({}) is smaller '\
'than {}% of number of samples per epoch ({}), '\
'setting separate_last_epoch to True'
else:
string = ' > last epoch number of samples ({}) is larger '\
'than {}% of number of samples per epoch ({}), '\
'setting separate_last_epoch to False'
print(string.format(last_epoch_num_samples, cutoff_last_epoch * 100,
num_samples_per_epoch), flush=True)
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
from megatron.data import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# shuffle-idx.
start_time = time.time()
# -1 is due to data structure used to retieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_,
sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading sample-idx mapping from {}'.format(
sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
sample_idx.shape[0]))
print_rank_0(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on number of samples and sequence lenght, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
# -1 is because we need to retrieve seq_length + 1 token each time
# but the last token will overlap with the first token of the next
# sample except for the last sample.
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document."""
if not separate_last_epoch or num_epochs == 1:
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch):
"""Sample index mapping is a 2D array with sizes
[number-of-samples + 1, 2] where [..., 0] contains
the index into `doc_idx` and [..., 1] is the
starting offset in that document."""
# Total number of samples. For -1 see comments in `_num_epochs`.
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
# Index into sample_idx.
sample_index = 0
# Index into doc_idx.
doc_idx_index = 0
# Begining offset for each document.
doc_offset = 0
# Start with first document and no offset.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
while sample_index <= num_samples:
# Start with a fresh sequence.
remaining_seq_length = seq_length + 1
while remaining_seq_length != 0:
# Get the document length.
doc_id = doc_idx[doc_idx_index]
doc_length = sizes[doc_id] - doc_offset
# And add it to the current sequence.
remaining_seq_length -= doc_length
# If we have more than a full sequence, adjust offset and set
# remaining length to zero so we return from the while loop.
# Note that -1 here is for the same reason we have -1 in
# `_num_epochs` calculations.
if remaining_seq_length <= 0:
doc_offset += (remaining_seq_length + doc_length - 1)
remaining_seq_length = 0
else:
# Otherwise, start from the begining of the next document.
doc_idx_index += 1
doc_offset = 0
# Record the sequence.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
return sample_idx
def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, size) and shuffle."""
print(' > building shuffle index with split [0, {}) and [{}, {}) '
'...'.format(num_samples, num_samples, total_size), flush=True)
dtype_ = np.uint32
if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx_first = np.arange(start=0, stop=num_samples,
step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
/*
coding=utf-8
Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/* Helper methods for fast index mapping builds */
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_blending_indices(py::array_t<uint8_t>& dataset_index,
py::array_t<int64_t>& dataset_sample_index,
const py::array_t<double>& weights,
const int32_t num_datasets,
const int64_t size, const bool verbose) {
/* Given multiple datasets and a weighting array, build samples
such that it follows those wieghts.*/
if (verbose) {
std::cout << "> building indices for blendable datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for(int64_t i = 0; i < num_datasets; ++i) {
current_samples[i] = 0;
}
// For each sample:
for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error) {
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<uint8_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose) {
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " <<
weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and it's length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Begining offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the begining of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
/* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1);
}
return max_length;
}
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " short sequence probability: " << short_seq_prob <<
endl << std::flush;
cout << " short sequence ration (1/prob): " << short_seq_ratio <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and it's length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the seed so both iterations produce the same results.
std::mt19937 rand32_gen(seed);
// Set the flag on second iteration.
second = (iteration == 1);
// Counters:
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// Current map index.
uint64_t map_index = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
// At the begining of the document previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have more than two sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
auto target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length.
// and if not only one sentence is left in the document.
// and if we have at least two sentneces.
// and if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow.
if ((3 * map_index + 2) >
std::numeric_limits<int64_t>::max()) {
cout << "number of samples exceeded maximum "
<< "allowed by type int64: "
<< std::numeric_limits<int64_t>::max()
<< endl;
throw std::overflow_error("Number of samples");
}
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
}
// Update indices / counters.
++map_index;
prev_start_index = sent_index + 1;
target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
}
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
// At the begining of the document previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length.
// and there are an acceptable number of sentences left
// and if we have at least the minimum number of sentences.
// or if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const py::array_t<int>& titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_blending_indices", &build_blending_indices);
}
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
rather than for training, since it is only built with a single epoch sample mapping.
"""
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
kwargs = dict(
name='full',
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
use_one_sent_docs=args.use_one_sent_docs
)
dataset = ICTDataset(**kwargs)
return dataset
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return len(self.samples_mapping)
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.
from functools import lru_cache
import os
import stat
import shutil
import struct
from itertools import accumulate
import numpy as np
import torch
from megatron import print_rank_0
def best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500:
return np.uint16
else:
return np.int32
def get_available_dataset_impl():
return ['lazy', 'cached', 'mmap']
def infer_dataset_impl(path):
if IndexedDataset.exists(path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
if magic == IndexedDataset._HDR_MAGIC:
return 'cached'
elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
return 'mmap'
else:
return None
else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
def make_builder(out_file, impl, dtype=None):
if impl == 'mmap':
assert dtype is not None
return MMapIndexedDatasetBuilder(out_file, dtype=dtype)
else:
assert dtype is None
return IndexedDatasetBuilder(out_file)
def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer':
impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path):
return IndexedDataset(path)
elif impl == 'cached' and IndexedDataset.exists(path):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
def dataset_exists(path, impl):
if impl == 'mmap':
return MMapIndexedDataset.exists(path)
else:
return IndexedDataset.exists(path)
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
8: np.uint16
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
raise ValueError(dtype)
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
def create_doc_idx(sizes):
doc_idx = [0]
for i, s in enumerate(sizes):
if s == 0:
doc_idx.append(i + 1)
return doc_idx
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00'
def __init__(self, path):
super().__init__()
self.path = path
self.data_file = None
self.read_index(path)
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == self._HDR_MAGIC, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self._len, self.s = struct.unpack('<QQ', f.read(16))
self.doc_count = struct.unpack('<Q', f.read(8))
self.dim_offsets = read_longs(f, self._len + 1)
self.data_offsets = read_longs(f, self._len + 1)
self.sizes = read_longs(f, self.s)
self.doc_idx = read_longs(f, self.doc_count)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self._len:
raise IndexError('index out of range')
def __del__(self):
if self.data_file:
self.data_file.close()
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if not self.data_file:
self.read_data(self.path)
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
return a
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
size = sum(sizes)
a = np.empty(size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[start] * self.element_size)
self.data_file.readinto(a)
offsets = list(accumulate(sizes))
sents = np.split(a, offsets[:-1])
return sents
def __len__(self):
return self._len
def num_tokens(self, index):
return self.sizes[index]
def size(self, index):
return self.sizes[index]
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path):
super().__init__(path)
self.cache = None
self.cache_index = {}
@property
def supports_prefetch(self):
return True
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
total_size += self.data_offsets[i + 1] - self.data_offsets[i]
self.cache = np.empty(total_size, dtype=self.dtype)
ptx = 0
self.cache_index.clear()
for i in indices:
self.cache_index[i] = ptx
size = self.data_offsets[i + 1] - self.data_offsets[i]
a = self.cache[ptx: ptx + size]
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
ptx += size
if self.data_file:
# close and delete data file after prefetch so we can pickle
self.data_file.close()
self.data_file = None
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
i = idx
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
ptx = self.cache_index[i]
np.copyto(a, self.cache[ptx: ptx + a.size])
return a
elif isinstance(idx, slice):
# Hack just to make this work, can optimizer later if necessary
sents = []
for i in range(*idx.indices(len(self))):
sents.append(self[i])
return sents
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.uint16: 2,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
@staticmethod
def write_header(fout, dtype, numdata, numsize, numdoc):
"""Writes header for cached indexed dataset to given file handle, return number of bytes written."""
startpos = fout.tell()
fout.write(IndexedDataset._HDR_MAGIC)
fout.write(struct.pack('<Q', 1))
fout.write(struct.pack('<Q', code(dtype)))
fout.write(struct.pack('<Q', IndexedDatasetBuilder.element_sizes[dtype]))
fout.write(struct.pack('<Q', numdata - 1))
fout.write(struct.pack('<Q', numsize))
fout.write(struct.pack('<Q', numdoc))
endpos = fout.tell()
return endpos - startpos
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
self.doc_idx = [0]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def end_document(self):
self.doc_idx.append(len(self.sizes))
def merge_file_(self, another_file):
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
doc_offset = len(self.sizes)
begin = self.data_offsets[-1]
for data_offset in index.data_offsets[1:]:
self.data_offsets.append(begin + data_offset)
self.sizes.extend(index.sizes)
begin = self.dim_offsets[-1]
for dim_offset in index.dim_offsets[1:]:
self.dim_offsets.append(begin + dim_offset)
self.doc_idx.extend( (doc_offset + index.doc_idx)[1:] )
with open(data_file_path(another_file), 'rb') as f:
while True:
data = f.read(1024)
if data:
self.out_file.write(data)
else:
break
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
IndexedDatasetBuilder.write_header(index, self.dtype, len(self.data_offsets), len(self.sizes), len(self.doc_idx))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
write_longs(index, self.doc_idx)
index.close()
def _warmup_mmap_file(path):
with open(path, 'rb') as stream:
while stream.read(100 * 1024 * 1024):
pass
def exscan_from_cumsum_(arr):
# given an array holding the result of an inclusive scan (cumsum),
# convert to an exclusive scan (shift to the right)
# [10, 30, 35, 50] --> [0, 10, 30, 35]
if arr.size > 1:
arr[1:] = arr[:-1]
if arr.size > 0:
arr[0] = 0
def get_pointers_with_total(sizes, elemsize, dtype):
"""Return a numpy array of type np.dtype giving the byte offsets.
Multiplies values in the sizes array by elemsize (bytes),
and then computes an exclusive scan to get byte offsets.
Returns the total number of bytes as second item in a tuple.
"""
# scale values in sizes array by elemsize to get sizes in bytes
pointers = np.array(sizes, dtype=dtype)
pointers *= elemsize
np.cumsum(pointers, axis=0, out=pointers)
# get total number of bytes from all sizes (last element)
bytes_last = pointers[-1] if len(sizes) > 0 else 0
# convert to byte offsets
exscan_from_cumsum_(pointers)
return pointers, bytes_last
class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b'MMIDIDX\x00\x00'
@staticmethod
def write_header(fout, dtype, numsizes, numdocs):
"""Writes header for mmap indexed dataset to given file handle, return number of bytes written."""
startpos = fout.tell()
fout.write(MMapIndexedDataset.Index._HDR_MAGIC)
fout.write(struct.pack('<Q', 1))
fout.write(struct.pack('<B', code(dtype)))
fout.write(struct.pack('<Q', numsizes))
fout.write(struct.pack('<Q', numdocs))
endpos = fout.tell()
return endpos - startpos
@classmethod
def writer(cls, path, dtype):
class _Writer(object):
def __enter__(self):
self._file = open(path, 'wb')
return self
@staticmethod
def _get_pointers(sizes, npdtype):
"""Return a numpy array of byte offsets given a list of sizes.
Multiplies values in the sizes array by dtype size (bytes),
and then computes an exclusive scan to get byte offsets.
"""
# compute element sizes in bytes
pointers, _ = get_pointers_with_total(sizes, dtype().itemsize, npdtype)
return pointers
def write(self, sizes, doc_idx):
MMapIndexedDataset.Index.write_header(self._file, dtype, len(sizes), len(doc_idx))
sizes32 = np.array(sizes, dtype=np.int32)
self._file.write(sizes32.tobytes(order='C'))
del sizes32
pointers = self._get_pointers(sizes, np.int64)
self._file.write(pointers.tobytes(order='C'))
del pointers
doc_idx = np.array(doc_idx, dtype=np.int64)
self._file.write(doc_idx.tobytes(order='C'))
def __exit__(self, exc_type, exc_val, exc_tb):
self._file.close()
return _Writer()
def __init__(self, path, skip_warmup=False):
with open(path, 'rb') as stream:
magic_test = stream.read(9)
assert self._HDR_MAGIC == magic_test, (
'Index file doesn\'t match expected format. '
'Make sure that --dataset-impl is configured properly.'
)
version = struct.unpack('<Q', stream.read(8))
assert (1,) == version
dtype_code, = struct.unpack('<B', stream.read(1))
self._dtype = dtypes[dtype_code]
self._dtype_size = self._dtype().itemsize
self._len = struct.unpack('<Q', stream.read(8))[0]
self._doc_count = struct.unpack('<Q', stream.read(8))[0]
offset = stream.tell()
if not skip_warmup:
print_rank_0(" warming up index mmap file...")
_warmup_mmap_file(path)
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes)
print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
@property
def dtype(self):
return self._dtype
@property
def sizes(self):
return self._sizes
@property
def doc_idx(self):
return self._doc_idx
@lru_cache(maxsize=8)
def __getitem__(self, i):
return self._pointers[i], self._sizes[i]
def __len__(self):
return self._len
def __init__(self, path, skip_warmup=False):
super().__init__()
self._path = None
self._index = None
self._bin_buffer = None
self._do_init(path, skip_warmup)
def __getstate__(self):
return self._path
def __setstate__(self, state):
self._do_init(state)
def _do_init(self, path, skip_warmup):
self._path = path
self._index = self.Index(index_file_path(self._path), skip_warmup)
if not skip_warmup:
print_rank_0(" warming up data mmap file...")
_warmup_mmap_file(data_file_path(self._path))
print_rank_0(" creating numpy buffer of mmap...")
self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
print_rank_0(" creating memory view of numpy buffer...")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def __del__(self):
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
del self._index
def __len__(self):
return len(self._index)
# @lru_cache(maxsize=8)
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
ptr = self._index._pointers[start]
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
def size(self, index):
return self._index.sizes[index]
@property
def doc_idx(self):
return self._index.doc_idx
def get_doc_idx(self):
return self._index._doc_idx
def set_doc_idx(self, doc_idx_):
self._index._doc_idx = doc_idx_
@property
def supports_prefetch(self):
return False
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
def dtype(self):
return self._index.dtype
class MMapIndexedDatasetBuilder(object):
def __init__(self, out_file, dtype=np.int64):
self._data_file = open(out_file, 'wb')
self._dtype = dtype
self._sizes = []
self._doc_idx = [0]
def add_item(self, tensor):
np_array = np.array(tensor.numpy(), dtype=self._dtype)
self._data_file.write(np_array.tobytes(order='C'))
self._sizes.append(np_array.size)
def end_document(self):
self._doc_idx.append(len(self._sizes))
def merge_file_(self, another_file):
# Concatenate index
index = MMapIndexedDataset.Index(index_file_path(another_file))
assert index.dtype == self._dtype
total_len = len(index.sizes)+len(self._sizes)
print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}")
offset = len(self._sizes)
self._sizes.extend(index.sizes)
self._doc_idx.extend( (offset + index.doc_idx)[1:] )
# Concatenate data
with open(data_file_path(another_file), 'rb') as f:
shutil.copyfileobj(f, self._data_file)
def finalize(self, index_file):
self._data_file.close()
with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
index.write(self._sizes, self._doc_idx)
# To merge a set of binary files, one can simply concatenate them in order.
# We stat each binary file to determine its size, execute a scan to compute
# the byte offset where the calling rank should write its data, seek to proper
# spot, and copy each file.
def gather_files_dist_bin(outfile, filelist, distctx):
"""Concatenate binary files in filelist into a new file given by outfile"""
# lookup size of each of our binary files
filesizes = [os.stat(data_file_path(f))[stat.ST_SIZE] for f in filelist]
# compute total bytes of the merged file and the offset
# at which this rank will write data from its files
numbytes = sum(filesizes)
count = distctx.sum(numbytes)
offset = distctx.exscan(numbytes)
# We first write to a temporary file name. We rename to the final name
# if successful or delete the temporary file if not.
# This way if the final name appears, the user knows it's a valid file.
finalname = data_file_path(outfile)
finalnametmp = finalname + ".tmp"
# First delete the final file if it already exists
distctx.remove(finalname)
# Catch I/O errors from any process
err = None
try:
# Create shared output file and pre-truncate to its final size.
with distctx.open(finalnametmp, truncate=count) as fout:
# Seek to appropriate starting offset in the merged file.
fout.seek(offset)
# Copy in contents of each of our files.
for f in filelist:
with open(data_file_path(f), "rb") as fsrc:
shutil.copyfileobj(fsrc, fout)
except Exception as e:
err = e
# Check that all ranks wrote successfully.
# This will raise an exception all on ranks if we detect
# an exception on any rank.
distctx.allraise_if(err)
# Everyone wrote their part successfully.
# Rename the temporary file to the final file.
distctx.rename(finalnametmp, finalname)
def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype):
"""Write list of vals to fout starting at an offset given by file_offset, elem_offset, and dtype.
Copies list of values in vals to a numpy array of type dtype.
Adds a constant shift value to all elements.
Writes the numpy array to the file handle at given offset and scaled by size of the datatype.
offset = file_offset + elem_offset * dtype().itemsize
Parameters
----------
fout : file handle
Open file handle to which to write list of vals
file_offset : int
Byte offset within the file where the global list starts
vals : list[int]
List of values to be written
shift : int
Value to add to each element in vals before writing (use 0 for no change)
elem_offset : int
Zero-based element index where vals starts within the global list.
This value is scaled by dtype().itemsize to convert to a corresponding byte offset.
dtype : np.dtype
numpy datatype to be used when writing the list to the file
"""
# Make a copy of the vals list using the requested datatype.
npvals = np.array(vals, dtype=dtype)
# Shift values in the list by a constant value.
npvals += shift
# Seek to proper offset for this rank and write
# values into file, stored as given datatype.
fout.seek(file_offset + elem_offset * dtype().itemsize)
fout.write(npvals.tobytes(order='C'))
def gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_code, distctx):
# Verify that no rank has found an inconsistent value in its own set of files.
# This includes an allreduce to verify that dtype_rank_consistent is True everywhere.
distctx.allassert(dtype_rank_consistent, "Some rank found inconsistent dtype values")
# Verify that at least one rank found a dtype value.
# Because of the bcast, the the value of first_dtype_code is the same on all ranks.
first_dtype_code = distctx.bcast_first(dtype_code)
assert first_dtype_code is not None, "Failed to find a dtype value in any index file"
# Verify that the dtype is consistent on all ranks, if a rank has a dtype value.
distctx.allassert(dtype_code == first_dtype_code or dtype_code is None, "Different dtype values detected in index files")
# return the dtype
return dtypes[first_dtype_code]
def gather_files_dist_idx_cached(outfile, filelist, distctx):
# Read each index file and append items to our lists
sizes = []
data_offsets = [0]
dim_offsets = [0]
doc_idx = [0]
dtype_rank_consistent = True # whether this rank identifies inconsistent dtype values in its files
dtype_value = None # the current dtype code to compare against, if any
for f in filelist:
# read index file for this file
index = IndexedDataset(f)
# append its size, data, dim, and doc entries to our lists
doc_offset = len(sizes)
sizes.extend(index.sizes)
data_offsets.extend(index.data_offsets[1:] + data_offsets[-1])
dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1])
doc_idx.extend(index.doc_idx[1:] + doc_offset)
# check that the dtype in this index matches the dtype in our other files
dtype_code = code(index.dtype)
if dtype_value is None:
dtype_value = dtype_code
if dtype_value != dtype_code:
dtype_rank_consistent = False
# Check that we have consistent dtypes in all files from all ranks,
# and return the dtype being used.
dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx)
# Capture the last value in the data array before we delete any items.
# Note this may be zero on any rank that has no items,
# but zero is the correct value in that case.
# We use this last value to compute a shift value that
# is later be added to each element in our data list.
data_shift = distctx.exscan(data_offsets[-1])
# Drop the zero entry from the lists that start with
# a "0" value unless we're rank 0.
if distctx.rank != 0:
del data_offsets[0]
del dim_offsets[0]
del doc_idx[0]
# Compute total number of entires in data, size, dim,
# and doc_idx lists across all ranks. Also compute the offset
# of the calling rank for each list considering the number
# of entries for all ranks before the calling rank.
numdata = len(data_offsets)
numsize = len(sizes)
numdim = len(dim_offsets)
numdoc = len(doc_idx)
global_data_count = distctx.sum(numdata)
global_size_count = distctx.sum(numsize)
global_dim_count = distctx.sum(numdim)
global_doc_count = distctx.sum(numdoc)
global_data_offset = distctx.exscan(numdata)
global_size_offset = distctx.exscan(numsize)
global_dim_offset = distctx.exscan(numdim)
global_doc_offset = distctx.exscan(numdoc)
# We first write to a temporary file name. We rename to the final name
# if successful or delete the temporary file if not.
# This way if the final name appears, the user knows it's a valid file.
finalname = index_file_path(outfile)
finalnametmp = finalname + ".tmp"
# First delete the final file if it already exists
distctx.remove(finalname)
# Catch and I/O errors to later determine whether all ranks wrote successfully.
err = None
try:
# Create shared output file
with distctx.open(finalnametmp) as fout:
# Have rank 0 write the file header
file_offset = 0
if distctx.rank == 0:
try:
file_offset = fout.tell()
file_offset += IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count)
except Exception as e:
err = e
distctx.allraise_if(err)
# Broadcast current file position from rank 0.
file_offset = distctx.bcast(file_offset, root=0)
# The dimension list records the offset within
# the sizes list for each sentence.
# We shift our dimension index values to account for the number of size values
# that come before the calling rank which is in global_size_offset.
write_list_at_offset(fout, file_offset, dim_offsets, global_size_offset, global_dim_offset, np.int64)
file_offset += global_dim_count * np.int64().itemsize
# The data index records the element offset to the start of each
# sentence within the binary data file. Note that this is an
# element offset, not a byte offset. Each element is pyhsically stored
# in the data file as dtype().itemsize bytes.
# We shift the data index values according to the number of elements that
# come before the calling rank, which is stored in data_shift.
write_list_at_offset(fout, file_offset, data_offsets, data_shift, global_data_offset, np.int64)
file_offset += global_data_count * np.int64().itemsize
# Each sentence is stored as a tensor.
# The tensor for each sentence can be multidimensional.
# The number of tensor dimensions per sentence is variable,
# and the size of each dimension of a sentence is arbitrary.
# The size list records a flattened list of the sizes
# for each dimension of a sentence.
# No shift value is needed.
write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int64)
file_offset += global_size_count * np.int64().itemsize
# The document index records the offset within the sizes
# array for the first sentence of each document.
# We shift the document index values for number of size values that
# come before the calling rank which is in global_size_offset.
write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_doc_offset, np.int64)
file_offset += global_doc_count * np.int64().itemsize
except Exception as e:
# if we encounter any exception while writing, store it for later
err = e
# Check that all ranks wrote successfully
distctx.allraise_if(err)
# Everyone wrote their part successfully.
# Rename the temporary file to the final file.
distctx.rename(finalnametmp, finalname)
def gather_files_dist_idx_mmap(outfile, filelist, distctx):
# Read each index file and append items to the size and doc_idx lists
sizes = []
doc_idx = [0]
dtype_rank_consistent = True # whether rank identifies inconsistent dtype values in its files
dtype_value = None # the current dtype code to compare against, if any
for f in filelist:
# read index file for this file
index = MMapIndexedDataset.Index(index_file_path(f))
# append its size and doc entries to our lists
docs_offset = len(sizes)
sizes.extend(index.sizes)
doc_idx.extend(index.doc_idx[1:] + docs_offset)
# check that the dtype in this index matches the dtype in our other files
dtype_code = code(index.dtype)
if dtype_value is None:
dtype_value = dtype_code
if dtype_value != dtype_code:
dtype_rank_consistent = False
# Check that we have consistent dtypes in all files from all ranks,
# and return the dtype being used.
dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx)
# Drop the zero entry from the lists that start with
# a "0" value unless we're rank 0
if distctx.rank != 0:
del doc_idx[0]
# Compute total number of size and document index
# values across all ranks. Also compute the offset
# of the calling rank for each value considering
# the values of sizes/docs for all ranks before the
# calling rank.
numsizes = len(sizes)
numdocs = len(doc_idx)
global_size_count = distctx.sum(numsizes)
global_docs_count = distctx.sum(numdocs)
global_size_offset = distctx.exscan(numsizes)
global_docs_offset = distctx.exscan(numdocs)
# Compute local byte offsets for each of our sentences given
# the token count and byte size of the vocab dtype.
pointers, pointers_bytes = get_pointers_with_total(sizes, dtype().itemsize, np.int64)
# Determine total number of bytes for all sentences on ranks
# before the calling rank.
pointer_offset = distctx.exscan(pointers_bytes)
# We first write to a temporary file name. We rename to the final name
# if successful or delete the temporary file if not.
# This way if the final name appears, the user knows it's a valid file.
finalname = index_file_path(outfile)
finalnametmp = finalname + ".tmp"
# First delete the final file if it already exists
distctx.remove(finalname)
# Catch and I/O errors to later determine whether all ranks wrote successfully.
err = None
try:
# Create shared output file
with distctx.open(finalnametmp) as fout:
# Have rank 0 write the file header
file_offset = 0
if distctx.rank == 0:
try:
file_offset = fout.tell()
file_offset += MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count)
except Exception as e:
err = e
distctx.allraise_if(err)
# Broadcast current file position from rank 0.
file_offset = distctx.bcast(file_offset, root=0)
# The list of size values from each rank are
# concatenated and stored as int32.
write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int32)
file_offset += global_size_count * np.int32().itemsize
# The pointer values store the byte offset to each sentence when in memory.
# A sentence has a variable number of tokens, given by
# its corresponding entry in the size array. Each token
# of a sentence is stored in units of type dtype, which consumes
# dtype().itemsize bytes (often a standard type that is just
# large enough to represent all elements of the vocabulary).
# Since the pointers array is the same length as the sizes array,
# we use global_size_offset and global_size_count to position
# within the file for writing the pointer values.
write_list_at_offset(fout, file_offset, pointers, pointer_offset, global_size_offset, np.int64)
file_offset += global_size_count * np.int64().itemsize
# The document index points to the position in the sizes
# array for the starting sentence of each document.
# A variable number of sentences can be in each document.
# We shift the document index for number of sentences that
# come before the calling rank which is in global_size_offset.
write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_docs_offset, np.int64)
file_offset += global_docs_count * np.int64().itemsize
except Exception as e:
# if we encounter any exception while writing, store it for later
err = e
# Check that all ranks wrote successfully
distctx.allraise_if(err)
# Everyone wrote their part successfully.
# Rename the temporary file to the final file.
distctx.rename(finalnametmp, finalname)
# Verify that all files in filelist are of the same index type.
# Returns the identified type {cached, mmap} as a string.
def gather_files_dist_check_impltype(filelist, distctx):
# Sanity check for typos in file names.
# Check that a data file exists for each of our files.
all_files_exist = all([os.path.exists(data_file_path(f)) for f in filelist])
# Check that all ranks have all of their files.
distctx.allassert(all_files_exist, "Some rank is missing its input file")
# map type string to an integer for easier bcast, use 0 for unknown
implmap = {"cached": 1, "mmap": 2}
# check that all files in filelist are of the same type
sametype = True
ourtype = None
for f in filelist:
# read header of index file to determine its type
impl = infer_dataset_impl(f)
implval = implmap[impl] if impl in implmap else 0
# check that the type matches our other files
if ourtype is None:
ourtype = implval
if ourtype != implval:
sametype = False
# Check that all ranks have the same type,
# and that there is no unknown type.
# This checks that:
# - all of our own files (if any) are of the same type AND
# - either we have no files or the type of our files match the broadcast type AND
# - the broadcast type is of a known type: {cached, mmap}
bcasttype = distctx.bcast_first(ourtype)
matchtype = sametype and (ourtype is None or ourtype == bcasttype) and bcasttype != 0
distctx.allassert(matchtype, "Cannot merge dataset files of different types")
# map back to return index string name
for key in implmap.keys():
if implmap[key] == bcasttype:
return key
# raise exception if key for bcasttype was not found
raise UnreachableCode
def gather_files_dist(filemain, filelist, distctx):
"""Collectively merge files into a new output file specified in filemain.
Each rank contributes a distinct list of zero or more files in filelist,
and each rank directly merges its set of files into filemain.
It is allowed for the input files in filelist to only be readable from the calling process.
In particular, the input files specified by the calling process may be in storage
that only the calling process can access, like /dev/shm or a node-local SSD.
The output file in filemain should be in a location that is writable by all processes.
NOTE: This uses parallel writes to a shared file to achieve high write bandwidth.
To do so, this implementation seeks beyond the end of the file to write at different
offsets from different processes via the seek() method on a python file handle.
The behavior of seek() is not well documented, but it seems to map to fseek()/lseek(),
and it works as desired on POSIX-compliant file systems like Lustre and GPFS."""
# Check that at least one input file is listed
filecount = distctx.sum(len(filelist))
assert filecount > 0, "All ranks have no input files to merge"
# Check that files are all of the same index type
indexstr = gather_files_dist_check_impltype(filelist, distctx)
# Concatenate the data files
gather_files_dist_bin(filemain, filelist, distctx)
# Combine index files into a single index file
if indexstr == "cached":
gather_files_dist_idx_cached(filemain, filelist, distctx)
elif indexstr == "mmap":
gather_files_dist_idx_mmap(filemain, filelist, distctx)
def get_start_end(count, rank, numranks):
"""Return (start, end) index values for calling rank to evenly divide count items among numranks.
Example usage:
start, end = get_start_end(len(itemlist), distctx.rank, distctx.numranks)
sublist = itemlist[start:end]
Parameters
----------
count : int
Total number of items to be divided
rank : int
Rank of the calling process, within range of [0, numranks)
numranks : int
Number of ranks by which to divide count items
Returns
----------
(start, end) : tuple(int)
Start and end index values that define the [start, end) range for rank
"""
num, remainder = divmod(count, numranks)
if rank < remainder:
start = (num + 1) * rank
end = start + num + 1
else:
start = (num + 1) * remainder + num * (rank - remainder)
end = start + num
return start, end
def merge_files_dist(filemain, filelist, distctx):
"""Merge list of indexed datasets into a single indexed dataset named in filemain.
Given a list of indexed datasets in filelist, and the set of processes defined
by the distributed environment in distctx, collectively merge files into
a new, single output indexed dataset named in filemain. This overwrites filemain
if it already exists. It does not delete the input datasets in filelist. The input
parameters filemain and filelist must be identical on all calling processes,
and all processes in distctx must call this method collectively.
It requires that all ranks be able to read any file in filelist, and all
ranks must be able to write to the single output file named in filemain."""
# TODO: if file sizes vary significantly, it might be better to consider
# file size when splitting the list to different ranks.
# evenly divide list of files among ranks
start, end = get_start_end(len(filelist), distctx.rank, distctx.numranks)
sublist = filelist[start:end]
# delegate merge to gather implementation
return gather_files_dist(filemain, sublist, distctx)
"""Non-Causal Mask Language Model Finetune Style dataset."""
import numpy as np
import torch
from megatron import print_rank_0, get_tokenizer, get_args
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
from megatron.data.gpt_dataset import GPTDataset
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
sequence_length,
noise_density,
mean_noise_span_length,
seed,
skip_warmup
):
assert noise_density is not None
assert mean_noise_span_length is not None
if len(data_prefix) == 1:
return _build_train_valid_test_datasets(
data_prefix=data_prefix[0],
data_impl=data_impl,
splits_string=splits_string,
train_valid_test_num_samples=train_valid_test_num_samples,
sequence_length=sequence_length,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
seed=seed,
skip_warmup=skip_warmup
)
# Blending dataset.
# Parse the values.
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
train_datasets = []
valid_datasets = []
test_datasets = []
for i in range(len(prefixes)):
train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(
data_prefix=prefixes[i],
data_impl=data_impl,
splits_string=splits_string,
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
sequence_length=sequence_length,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
seed=seed,
skip_warmup=skip_warmup
)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
def build_dataset_group(
dataset_group_name,
paths,
weights,
splits,
data_impl,
train_valid_test_num_samples,
seq_length,
noise_density,
mean_noise_span_length,
seed,
skip_warmup,
train_valid_test
):
'''
Build a single dataset group corresponding to Option 2 of data loading see arguments.py
a dataset group is passed on the following form
GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2
or alternatively
GIVEN_NAME PATH1 # for a single dataset to be used fully
'''
assert train_valid_test in ["train","valid","test"]
# Single dataset.
if len(paths) == 1:
dataset = _build_single_datasets(
data_prefix=paths[0],
range_string=splits[0],
data_impl=data_impl,
train_valid_test_num_samples=train_valid_test_num_samples,
sequence_length=seq_length,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
seed=seed,
skip_warmup=skip_warmup,
dataset_group_name=dataset_group_name,
train_valid_test=train_valid_test)
return dataset
# Blending dataset.
else:
data_prefix = []
# data_prefix is on the shape:
# ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"]
for w,p in zip(weights, paths):
data_prefix += [w,p]
output = get_datasets_weights_and_num_samples(data_prefix,
train_valid_test_num_samples)
prefixes, weights, datasets_train_valid_test_num_samples = output
# Build individual datasets.
datasets = []
for i in range(len(prefixes)):
ds = _build_single_datasets(
data_prefix=prefixes[i],
range_string=splits[i],
data_impl=data_impl,
train_valid_test_num_samples=datasets_train_valid_test_num_samples[i],
sequence_length=seq_length,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
seed=seed,
skip_warmup=skip_warmup,
dataset_group_name=dataset_group_name,
train_valid_test=train_valid_test
)
datasets.append(ds)
all_datasets = BlendableDataset(datasets, weights)
return all_datasets
def _build_single_datasets(
data_prefix,
range_string,
data_impl,
train_valid_test_num_samples,
sequence_length,
noise_density,
mean_noise_span_length,
seed,
skip_warmup,
dataset_group_name,
train_valid_test):
"""Build a single dataset"""
assert train_valid_test in ["train","valid","test"]
index = ["train","valid","test"].index(train_valid_test)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
# this corresponds to option2 for data loading on the form
# WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3
# splits here is an array of size 2 [start_index, end_index]
splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
print_rank_0(' {}:'.format(dataset_group_name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[0], splits[1],
splits[1] - splits[0]))
def build_dataset(name):
dataset = None
if splits[1] > splits[0]:
documents = np.arange(start=splits[0], stop=splits[1],
step=1, dtype=np.int32)
dataset = MLMDataset(
indexed_dataset=indexed_dataset,
documents=documents,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
name=name,
data_prefix=data_prefix,
sequence_length=sequence_length,
num_samples=train_valid_test_num_samples[index],
seed=seed,
)
return dataset
dataset = build_dataset(dataset_group_name)
return dataset
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
sequence_length,
noise_density,
mean_noise_span_length,
seed,
skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
# Build the dataset accordingly.
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = MLMDataset(
indexed_dataset=indexed_dataset,
documents=documents,
noise_density=noise_density,
mean_noise_span_length=mean_noise_span_length,
name=name,
data_prefix=data_prefix,
sequence_length=sequence_length,
num_samples=train_valid_test_num_samples[index],
seed=seed,
)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
class MLMDataset(torch.utils.data.Dataset):
def __init__(
self,
name,
indexed_dataset,
documents,
data_prefix,
sequence_length,
num_samples,
seed,
noise_density=0.15,
mean_noise_span_length=3
):
# Params to store.
self.name = name
self.seed = seed
self.sequence_length = sequence_length
# Dataset.
self.indexed_dataset = indexed_dataset
self.noise_density = noise_density
self.mean_noise_span_length = mean_noise_span_length
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
# To ensure that the input length is `sequence_length`, we need to increase the maximum length
# according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
sequence_length=self.sequence_length,
noise_density=self.noise_density,
mean_noise_span_length=self.mean_noise_span_length
)
self.inputs_length = inputs_length
# In order to compute loss, we need an extra token at the end.
self.number_of_raw_tokens = number_of_raw_tokens + 1
self.targets_length = targets_length + 1
self.num_noise_spans = num_noise_spans
# Build the samples mapping.
self._gpt_dataset = GPTDataset(
name=self.name,
data_prefix=data_prefix,
documents=documents,
indexed_dataset=self.indexed_dataset,
num_samples=num_samples,
# -1 because GPTDataset will return `seq_length + 1` sequences.
seq_length=self.number_of_raw_tokens - 1,
seed=seed
)
# Vocab stuff.
tokenizer = get_tokenizer()
self.sep_id = tokenizer.sep
self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
assert self.sep_id is not None, "MLM dataset requires tokenizer to have a <sep> token"
assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"
args = get_args()
# TODO @thomasw21 check once we merge t5
assert self.inputs_length + self.targets_length == args.seq_length + 1
def __len__(self):
return len(self._gpt_dataset)
def __getitem__(self, idx):
if isinstance(idx, slice):
raise NotImplementedError
sample = self._gpt_dataset[idx]["text"]
return build_training_sample(
sample=sample,
inputs_length=self.inputs_length,
targets_length=self.targets_length,
num_noise_spans=self.num_noise_spans,
sep_id=self.sep_id,
all_sentinel_token_ids=self.sentinel_token_ids,
)
def build_training_sample(
sample,
inputs_length,
targets_length,
num_noise_spans,
sep_id,
all_sentinel_token_ids,
):
"""Build training sample.
Arguments:
sample: int32 tensor
inputs_length: integer
targets_length: integer
num_noise_spans: integer
sep_id: integer
all_sentinel_token_ids: List[int]
Returns:
Dict with following keys:
- `input_tokens`: int32 tensor with as length input_length,
- `target_tokens`: int32 tensor with as length targets_length + 1,
"""
spans_start, mask_indices = random_spans_noise_mask(
inputs_length=inputs_length,
targets_length=targets_length,
num_noise_spans=num_noise_spans,
)
spans_end = np.concatenate([
spans_start[1:], np.full((1,), len(sample), dtype=np.int32)]
)
sentinel_token_ids = all_sentinel_token_ids[:num_noise_spans]
input_token_ids = np.concatenate(
[
elt
for start, end, sentinel_token in zip(spans_start[::2], spans_end[::2], sentinel_token_ids)
for elt in [sample[start: end], np.full((1,), sentinel_token, dtype=np.int32)]
] +
[np.full((1,), sep_id, dtype=np.int32)]
)
target_token_ids = np.concatenate(
[
elt
for start, end, sentinel_token in zip(spans_start[1::2], spans_end[1::2], sentinel_token_ids)
for elt in [np.full((1,), sentinel_token, dtype=np.int32), sample[start: end]]
] +
[np.full((1,), sep_id, dtype=np.int32)]
)
return {
'input_tokens': input_token_ids,
'target_tokens': target_token_ids
}
def compute_input_and_target_lengths(sequence_length, noise_density, mean_noise_span_length):
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
Training parameters to avoid padding with random_spans_noise_mask.
When training a model with random_spans_noise_mask, we would like to set the other
training hyperparmeters in a way that avoids padding.
This function helps us compute these hyperparameters.
The number of noise tokens and the number of noise spans and non-noise spans
are determined deterministically as follows:
num_noise_tokens = round(length * noise_density)
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
This function tells us the required number of tokens in the raw example (for split_tokens())
as well as the length of the encoded targets. Note that this function assumes
the inputs and targets will have SEP appended and includes that in the reported length.
Args:
inputs_length: an integer - desired length of the tokenized inputs sequence
noise_density: a float
mean_noise_span_length: a float
Returns:
tokens_length: length of original text in tokens
targets_length: an integer - length in tokens of encoded targets sequence
"""
def _tokens_length_to_inputs_length_targets_length(_tokens_length):
num_noise_tokens = int(round(_tokens_length * noise_density))
num_nonnoise_tokens = _tokens_length - num_noise_tokens
_num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
# inputs contain all nonnoise tokens, sentinels for all noise spans and one SEP token.
_input_length = num_nonnoise_tokens + _num_noise_spans + 1
_output_length = num_noise_tokens + _num_noise_spans + 1
return _input_length, _output_length, _num_noise_spans
tokens_length = sequence_length
inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
while inputs_length + targets_length > sequence_length:
tokens_length -= 1
inputs_length, targets_length, num_noise_spans = _tokens_length_to_inputs_length_targets_length(tokens_length)
# tokens_length is the number of raw tokens we need to get
# inputs_length will be the input
# targets_length will be the target
# num_noise_spans is the number of spans we have to replace
return tokens_length, inputs_length, targets_length, num_noise_spans
def random_spans_noise_mask(
inputs_length,
targets_length,
num_noise_spans,
):
"""This function is inspired from `random_spans_noise_mask <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
Spans alternate between non-noise and noise, beginning with non-noise.
Args:
inputs_length: int32 scalar
targets_length: int32 scalar
num_noise_spans: int32 scalar
Returns:
a int8 tensor with shape [num_noise_spans]
a boolean tensor with shape [length]
"""
# # pick the lengths of the noise spans and the non-noise spans
num_noise_tokens = targets_length - num_noise_spans - 1
num_nonnoise_tokens = inputs_length - num_noise_spans - 1
number_of_raw_tokens = num_noise_tokens + num_nonnoise_tokens
def _random_segmentation(num_items, num_segments):
"""Partition a sequence of items randomly into non-empty segments.
Args:
num_items: an integer scalar > 0
num_segments: an integer scalar in [1, num_items]
Returns:
a Tensor with shape [num_segments] containing positive integers that add
up to num_items
"""
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
# TODO @thomasw21 handle random state correctly, ie synchronized across TP.
# we might not care as get_batch_pipe broadcasts data to all devices.
np.random.shuffle(mask_indices)
first_in_segment = np.pad(mask_indices, [[1, 0]], constant_values=0)
segment_id = np.cumsum(first_in_segment)
# count length of sub segments assuming that list is sorted
_, segment_length = np.unique(segment_id, return_counts=True)
return segment_length
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
interleaved_span_lengths = np.reshape(
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
)
span_starts = np.concatenate([np.full((1,), 0, dtype=np.int32), np.cumsum(interleaved_span_lengths)[:-1]])
span_start_indicator = np.zeros((number_of_raw_tokens,), dtype=np.int8)
span_start_indicator[span_starts] = True
span_num = np.cumsum(span_start_indicator)
is_noise = np.equal(span_num % 2, 1)
return span_starts, is_noise
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Finetune style dataset."""
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
class MTFDataset(torch.utils.data.Dataset):
def __init__(
self,
name,
data_prefix,
data_impl,
skip_warmup,
documents,
):
# Params to store.
self.name = name
# Dataset.
self.input_indexed_dataset = get_indexed_dataset(data_prefix, is_input=True, data_impl=data_impl, skip_warmup=skip_warmup)
self.target_indexed_dataset = get_indexed_dataset(data_prefix, is_input=False, data_impl=data_impl, skip_warmup=skip_warmup)
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < self.input_indexed_dataset.sizes.shape[0]
assert np.max(documents) < self.target_indexed_dataset.sizes.shape[0]
assert self.input_indexed_dataset.sizes.shape[0] == self.target_indexed_dataset.sizes.shape[0]
def __len__(self):
return len(self.input_indexed_dataset)
def __getitem__(self, idx):
input_tokens = self.input_indexed_dataset.get(idx)
target_tokens = self.target_indexed_dataset.get(idx)
assert len(input_tokens) > 0
assert len(target_tokens) > 0
return {
'input_tokens': input_tokens,
'target_tokens': target_tokens,
}
def size(self, index):
return {
'input_tokens': self.input_indexed_dataset.size(index),
'target_tokens': self.target_indexed_dataset.size(index),
}
def get_indexed_dataset(data_prefix: str, is_input: bool, data_impl: str, skip_warmup: bool):
if is_input:
field = "inputs"
else:
field = "targets"
return get_indexed_dataset_(f"{data_prefix}_{field}_document", data_impl, skip_warmup)
def get_indexed_dataset_(path, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(path,
data_impl,
skip_warmup)
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' number of documents: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikipedia dataset from DPR code for ORQA."""
from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
args = get_args()
tokenizer = get_tokenizer()
dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
'evidence',
args.evidence_data_path,
tokenizer,
args.retriever_seq_length)
return dataset
def get_open_retrieval_batch(data_iterator):
# Items and their type.
keys = ['row_id', 'context', 'context_mask', 'context_types',
'context_pad_mask']
datatype = torch.int64
# Broadcast data.
data = None if data_iterator is None else next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
row_id = data_b['row_id'].long()
context = data_b['context'].long()
# TODO: make the context mask a binary one
context_mask = (data_b['context_mask'] < 0.5)
context_types = data_b['context_types'].long()
context_pad_mask = data_b['context_pad_mask'].long()
return row_id, context, context_mask, context_types, context_pad_mask
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
title_ids = tokenizer.tokenize(row['title'])
context_ids = tokenizer.tokenize(row['text'])
# Appending the title of the context at front
extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_ids(extended_context_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return context_ids, context_types, context_pad_mask
# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
def build_sample(row_id, context_ids, context_types, context_pad_mask):
"""Convert to numpy and return a sample consumed by the batch producer."""
context_ids = np.array(context_ids, dtype=np.int64)
context_types = np.array(context_types, dtype=np.int64)
context_mask = make_attention_mask(context_ids, context_ids)
sample = ({
'row_id': row_id,
'context': context_ids,
'context_mask': context_mask,
'context_types': context_types,
'context_pad_mask': context_pad_mask
})
return sample
class OpenRetrievalEvidenceDataset(ABC, Dataset):
"""Open Retrieval Evidence dataset class."""
def __init__(self, task_name, dataset_name, datapath, tokenizer,
max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
print_rank_0(datapath)
self.samples, self.id2text = self.process_samples_from_single_path(
datapath)
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
row = self.samples[idx]
context_ids, context_types, context_pad_mask = \
build_tokens_types_paddings_from_text(row, self.tokenizer,
self.max_seq_length)
sample = build_sample(row['doc_id'],
context_ids,
context_types,
context_pad_mask)
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
total = 0
rows = []
id2text = {}
with open(filename) as tsvfile:
reader = csv.reader(tsvfile, delimiter='\t')
next(reader, None) # skip the headers
for row in reader:
# file format: doc_id, doc_text, title
doc_id = int(row[0])
text = row[1]
title = row[2]
rows.append({'doc_id': doc_id,
'text': text,
'title': title})
assert doc_id not in id2text
id2text[doc_id] = (text, title)
total += 1
if total % 100000 == 0:
print_rank_0(' > processed {} rows so far ...'.format(
total))
print_rank_0(' >> processed {} samples.'.format(len(rows)))
return rows, id2text
import os
import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu
def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
global_batch_size = micro_batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
assert False, 'DistributedBatchSampler deprecated, change the implementation'
from megatron.data.samplers import DistributedBatchSampler
batch_sampler = DistributedBatchSampler(sampler,
batch_size=global_batch_size,
drop_last=False,
rank=rank,
world_size=world_size)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_pad_mask',
'block_tokens', 'block_pad_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_pad_mask = data_b['query_pad_mask'].long()
block_tokens = data_b['block_tokens'].long()
block_pad_mask = data_b['block_pad_mask'].long()
block_indices = data_b['block_data'].long()
return query_tokens, query_pad_mask,\
block_tokens, block_pad_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
import itertools
import os
import pickle
import shutil
import numpy as np
import torch
from megatron import get_args
from megatron import mpu
def detach(tensor):
return tensor.detach().cpu().numpy()
class OpenRetreivalDataStore(object):
"""
Serializable data structure for holding data for blocks --
embeddings and necessary metadata for Retriever
"""
def __init__(self, embedding_path=None, load_from_path=True, rank=None):
self.embed_data = dict()
if embedding_path is None:
args = get_args()
embedding_path = args.embedding_path
rank = args.rank
self.embedding_path = embedding_path
self.rank = rank
if load_from_path:
self.load_from_file()
block_data_name = os.path.splitext(self.embedding_path)[0]
self.temp_dir_name = block_data_name + '_tmp'
def state(self):
return {
'embed_data': self.embed_data,
}
def clear(self):
"""
Clear the embedding data structures to save memory.
The metadata ends up getting used, and is also much smaller in
dimensionality so it isn't really worth clearing.
"""
self.embed_data = dict()
def load_from_file(self):
"""Populate members from instance saved to file"""
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.embedding_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
"""
Add data for set of blocks
:param row_id: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
In the case of retriever this will be [start_idx, end_idx, doc_idx]
"""
for idx, embed in zip(row_id, block_embeds):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed)
def save_shard(self):
"""
Save the block data that was created this in this process
"""
if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True)
# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
as writer:
pickle.dump(self.state(), writer)
def merge_shards_and_save(self):
#Combine all the shards made using save_shard
shard_names = os.listdir(self.temp_dir_name)
seen_own_shard = False
for fname in os.listdir(self.temp_dir_name):
shard_rank = int(os.path.splitext(fname)[0])
if shard_rank == self.rank:
seen_own_shard = True
continue
with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
data = pickle.load(f)
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])
# add the shard's data and check to make sure there
# is no overlap
self.embed_data.update(data['embed_data'])
assert len(self.embed_data) == old_size + shard_size
assert seen_own_shard
# save the consolidated shards and remove temporary directory
with open(self.embedding_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)
print("Finished merging {} shards for a total of {} embeds".format(
len(shard_names), len(self.embed_data)), flush=True)
class FaissMIPSIndex(object):
"""
Wrapper object for a BlockData which similarity search via FAISS under the hood
"""
def __init__(self, embed_size, embed_data=None, use_gpu=False):
self.embed_size = embed_size
self.embed_data = embed_data
self.use_gpu = use_gpu
self.mips_index = None
self._set_mips_index()
def _set_mips_index(self):
"""
Create a Faiss Flat index with inner product as the metric
to search against
"""
try:
import faiss
except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex")
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
cpu_index = faiss.IndexFlatIP(self.embed_size)
if self.use_gpu:
# create resources and config for GpuIndex
config = faiss.GpuMultipleClonerOptions()
config.shard = True
config.useFloat16 = True
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU", flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it
# when the FAISS structure is built
if self.embed_data is not None:
self.add_embed_data(self.embed_data)
def reset_index(self):
"""Delete existing index and create a new"""
del self.mips_index
# reset the block data so that _set_block_index will reload it as well
if self.embed_data is not None:
embed_data_path = self.embed_data.embedding_path
del self.embed_data
self.embed_data = OpenRetreivalDataStore(embed_data_path)
self._set_mips_index()
def update_index(self):
"""Delete existing index and create a new"""
del self.mips_index
# reset the block data so that _set_mips_index will reload it as well
if self.embed_data is not None:
self.embed_data.load_from_file()
self._set_mips_index()
def add_embed_data(self, all_embed_data):
"""Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math
# internally is done with float16.
embeds_arr = np.float32(np.array(block_embeds))
indices_arr = np.array(block_indices)
# we no longer need the embedding data since it's in the index now
all_embed_data.clear()
self.mips_index.add_with_ids(embeds_arr, indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""
Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim]
array of blocks
if False: return [num_queries x k] array of
distances, and another for indices
"""
query_embeds = np.float32(detach(query_embeds))
if reconstruct:
# get the vectors themselves
top_k_block_embeds = self.mips_index.search_and_reconstruct(\
query_embeds, top_k)
return top_k_block_embeds
else:
# get distances and indices of closest vectors
distances, block_indices = self.mips_index.search(query_embeds, top_k)
return distances, block_indices
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5 Style dataset."""
import collections
import numpy as np
import torch
from megatron import get_tokenizer
from megatron.data.dataset_utils import (
create_masked_lm_predictions,
get_samples_mapping
)
class T5Dataset(torch.utils.data.Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, max_seq_length_dec,
short_seq_prob, seed):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.max_seq_length_dec = max_seq_length_dec
# Dataset.
self.indexed_dataset = indexed_dataset
# Build the samples mapping.
self.samples_mapping = get_samples_mapping(self.indexed_dataset,
data_prefix,
num_epochs,
max_num_samples,
self.max_seq_length - 2, # account for added tokens
short_seq_prob,
self.seed,
self.name,
False)
# Vocab stuff.
tokenizer = get_tokenizer()
self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.cls_id = tokenizer.cls
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
self.bos_id = tokenizer.bos_token_id
self.eos_id = tokenizer.eos_token_id
self.sentinel_tokens = tokenizer.additional_special_tokens_ids
assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.max_seq_length_dec,
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng,
self.bos_id, self.eos_id,
self.sentinel_tokens)
def build_training_sample(sample, target_seq_length,
max_seq_length, max_seq_length_dec,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng, bos_id=None,
eos_id=None, sentinel_tokens=None):
"""Build training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number genenrator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the opper bound whereas the numpy one is exclusive.
bos_id: start of decoder example id
eos_id: end of generation id
sentinel_tokens: unique value to be substituted for every replaced span
"""
assert target_seq_length <= max_seq_length
# flatten sentences into one list
tokens = [token for sentence in sample for token in sentence]
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = len(tokens) > max_num_tokens
tokens = tokens[:max_num_tokens]
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _, masked_spans) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng,
max_ngrams=10, geometric_dist=True, masking_style="t5")
# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)
train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
return train_sample
def pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id,
max_seq_length, max_seq_length_dec,
masked_spans=None, bos_id=None,
eos_id=None, sentinel_tokens=None):
"""Pad sequences and convert them to numpy."""
sentinel_tokens = collections.deque(sentinel_tokens)
t5_input = []
(t5_decoder_in, t5_decoder_out) = ([bos_id], [])
(start_index, end_index) = (0, None)
for span in masked_spans:
flag = sentinel_tokens.popleft()
# Append the same tokens in decoder input and output
t5_decoder_in.append(flag)
t5_decoder_in.extend(span.label)
t5_decoder_out.append(flag)
t5_decoder_out.extend(span.label)
end_index = span.index[0]
t5_input.extend(tokens[start_index: end_index])
t5_input.append(flag)
# the next start index is the token after the last span token
start_index = span.index[-1] + 1
# Add <eos> token to the t5_decoder_out
t5_decoder_out.append(eos_id)
# Add the remaining tokens to the t5 input
t5_input.extend(tokens[start_index:])
# assert (len(t5_input) - len(masked_spans)) + \
# (len(t5_decoder_in) - (len(masked_spans) + 1)) == len(tokens)
# Some checks.
# Encoder-side padding mask.
num_tokens = len(t5_input)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(masked_positions) == len(masked_labels)
# Tokens..
filler = [pad_id] * padding_length
tokens_enc = np.array(t5_input + filler, dtype=np.int64)
# Decoder-side padding mask.
num_tokens_dec = len(t5_decoder_in)
padding_length_dec = max_seq_length_dec - num_tokens_dec
assert padding_length_dec >= 0
filler_dec = [pad_id] * padding_length_dec
tokens_dec_in = np.array(t5_decoder_in + filler_dec, dtype=np.int64)
# Create attention masks
enc_mask = make_attention_mask(tokens_enc, tokens_enc)
enc_dec_mask = make_attention_mask(tokens_dec_in, tokens_enc)
dec_mask = make_attention_mask(tokens_dec_in, tokens_dec_in)
dec_mask = dec_mask * make_history_mask(tokens_dec_in)
# Labels mask.
labels = t5_decoder_out + ([-1] * padding_length_dec)
labels = np.array(labels, dtype=np.int64)
# Loss mask
loss_mask = ([1] * num_tokens_dec) + ([0] * padding_length_dec)
loss_mask = np.array(loss_mask, dtype=np.int64)
return tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def make_attention_mask_3d(source_block, target_block):
"""
Returns a 3-dimensional (3-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[:, None, :] >= 1) * (source_block[:, :, None] >= 1)
# (batch, source_length, target_length)
# mask = mask.astype(np.int64)
return mask
def make_history_mask(block):
length = block.shape[0]
arange = np.arange(length)
history_mask = (arange[None, ] <= arange[:, None])
history_mask = history_mask.astype(np.int64)
return history_mask
def make_history_mask_3d(block):
batch, length = block.shape
arange = torch.arange(length, device=block.device)
history_mask = (arange[None, ] <= arange[:, None])[None, ]
history_mask = history_mask.expand(batch, length, length)
return history_mask
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse
import os
import sys
import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
print(len(ds.doc_idx))
print(len(ds))
print(ds.doc_idx[-1])
if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds)))
if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i]
end = ds.doc_idx[i + 1]
ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
for s in ids:
assert len(s) > 0
l = s.data.tolist()
text = tokenizer.detokenize(l)
print(text)
print("---")
def test_indexed_dataset_get(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = build_tokenizer(args)
size = ds.sizes[0]
print(f"size: {size}")
full = ds.get(0)
print(full)
# print(tokenizer.detokenize(full.data.tolist()))
print("---")
end = ds.get(0, offset=size - 10)
print(end)
# print(tokenizer.detokenize(end.data.tolist()))
start = ds.get(0, length=10)
print(start)
# print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# # ds = AlbertDataset(idataset, tokenizer)
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
# args.epochs, args.max_num_samples,
# args.masked_lm_prob, args.seq_length,
# args.short_seq_prob, args.seed)
# truncated = 0
# total = 0
# for i, s in enumerate(ds):
# ids = s['text']
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
# print(tokens)
# if i >= args.count-1:
# exit()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None,
help='Maximum number of samples to plan for')
parser.add_argument('--masked-lm-prob', type=float, default=0.15,
help='probability of masking tokens')
parser.add_argument('--seq-length', type=int, default=512,
help='maximum sequence length')
parser.add_argument('--short-seq-prob', type=float, default=0.1,
help='probability of creating a short sequence')
parser.add_argument('--seed', type=int, default=1234,
help='random seed')
args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
# test_albert_dataset(args)
test_indexed_dataset_get(args)
if __name__ == "__main__":
main()
#!/bin/bash
IMPL=cached
python ../preprocess_data.py \
--input test_samples.json \
--vocab vocab.txt \
--dataset-impl ${IMPL} \
--output-prefix test_samples_${IMPL} \
--workers 1 \
--log-interval 2
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy
def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
# training dataset
train_data_path = os.path.join(data_path[0], "train")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
process = [
transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
]
if color_jitter:
process += [
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
]
fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train
)
# validation dataset
val_data_path = os.path.join(data_path[0], "val")
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
val_data = datasets.ImageFolder(
root=val_data_path, transform=transform_val
)
# test dataset
test_data_path = os.path.join(data_path[0], "test")
transform_test = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
fp16_t
]
)
test_data = datasets.ImageFolder(
root=test_data_path, transform=transform_test
)
return train_data, val_data, test_data
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2 # Overrides `attention_mask` to be a lower triangular matrix
prefix = 3
custom = 4 # Forces one to pass an `attention_mask` that's 1 if we need to mask. Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length]
class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
alibi = 3
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment