Commit 69d8ff1c authored by Jared Casper

Merge branch 'staging_gpt2_dataloader' into 'staging'

GPT2 dataloader using indexed dataset

See merge request ADLR/megatron-lm!45
parents 7890681a 0afe1fde
# ===========
# base images
# ===========
FROM nvcr.io/nvidia/pytorch:19.09-py3
# ===============
# system packages
# ===============
RUN apt-get update && apt-get install -y \
bash-completion \
emacs \
git \
graphviz \
htop \
libopenexr-dev \
rsync \
wget \
&& rm -rf /var/lib/apt/lists/*
# ============
# pip packages
# ============
RUN pip install --upgrade pip && \
pip install --upgrade setuptools
COPY requirements.txt /tmp/
RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt
boto3
google-cloud-language
inflect
nltk
numpy
pandas
requests
sentencepiece
tensorflow
tqdm
@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron import print_rank_0
@@ -249,6 +248,7 @@ def get_samples_mapping_(indexed_dataset,
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
...
@@ -13,124 +13,305 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 style dataset."""

import os
import time

import numpy as np
import torch

from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                    train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup):
    """Build train, valid, and test datasets."""

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0(' {}:'.format(name))
        print_rank_0(' document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            documents = np.arange(start=splits[index], stop=splits[index + 1],
                                  step=1, dtype=np.int32)
            dataset = GPT2Dataset(name, data_prefix,
                                  documents, indexed_dataset,
                                  train_valid_test_num_samples[index],
                                  seq_length, seed)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Build indexed dataset."""
    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


class GPT2Dataset(torch.utils.data.Dataset):

    def __init__(self, name, data_prefix, documents, indexed_dataset,
                 num_samples, seq_length, seed):

        self.name = name
        self.indexed_dataset = indexed_dataset

        # Checks
        assert np.min(documents) >= 0
        assert np.max(documents) < indexed_dataset.sizes.shape[0]

        # Build index mappings.
        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
            self.name, data_prefix, documents, self.indexed_dataset.sizes,
            num_samples, seq_length, seed)

    def __len__(self):
        # -1 is due to the data structure used to retrieve the index:
        #    sample i --> [sample_idx[i], sample_idx[i+1])
        return self.sample_idx.shape[0] - 1

    def __getitem__(self, idx):
        # Get the shuffled index.
        idx = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[idx][0]
        doc_index_l = self.sample_idx[idx + 1][0]
        offset_f = self.sample_idx[idx][1]
        offset_l = self.sample_idx[idx + 1][1]
        # If we are within the same document, just extract the chunk.
        if doc_index_f == doc_index_l:
            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                              offset=offset_f,
                                              length=offset_l - offset_f + 1)
        else:
            # Otherwise, get the rest of the initial document.
            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
                                                    offset=offset_f)]
            # Loop over all in-between documents and add the entire document.
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # And finally add the relevant portion of the last document.
            sample_list.append(self.indexed_dataset.get(
                self.doc_idx[doc_index_l],
                length=offset_l + 1))
            sample = np.concatenate(sample_list)

        return {'text': np.array(sample, dtype=np.int64)}


def _build_index_mappings(name, data_prefix, documents, sizes,
                          num_samples, seq_length, seed):
    """Build doc-idx, sample-idx, and shuffle-idx.
    doc-idx: is an array (ordered) of documents to be used in training.
    sample-idx: is the start document index and document offset for each
       training sample.
    shuffle-idx: maps the sample index into a random index into sample-idx.
    """
    # Number of tokens in each epoch and number of required epochs.
    tokens_per_epoch = _num_tokens(documents, sizes)
    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
    # rng state
    np_rng = np.random.RandomState(seed=seed)

    # Filename of the index mappings.
    _filename = data_prefix
    _filename += '_{}_indexmap'.format(name)
    _filename += '_{}ns'.format(num_samples)
    _filename += '_{}sl'.format(seq_length)
    _filename += '_{}s'.format(seed)
    doc_idx_filename = _filename + '_doc_idx.npy'
    sample_idx_filename = _filename + '_sample_idx.npy'
    shuffle_idx_filename = _filename + '_shuffle_idx.npy'

    # Build the indexed mapping if it does not exist.
    if torch.distributed.get_rank() == 0:
        if (not os.path.isfile(doc_idx_filename)) or \
           (not os.path.isfile(sample_idx_filename)) or \
           (not os.path.isfile(shuffle_idx_filename)):

            print_rank_0(' > WARNING: could not find index map files, building '
                         'the indices on rank 0 ...')
            # doc-idx.
            start_time = time.time()
            doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save doc-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # sample-idx.
            start_time = time.time()
            # Use C++ implementation for speed.
            from megatron.data import helpers
            assert doc_idx.dtype == np.int32
            assert sizes.dtype == np.int32
            sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                                  num_epochs, tokens_per_epoch)
            # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
            #                                num_epochs, tokens_per_epoch)
            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save sample-idx mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))
            # shuffle-idx.
            start_time = time.time()
            # -1 is due to the data structure used to retrieve the index:
            #    sample i --> [sample_idx[i], sample_idx[i+1])
            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
            print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
                         ' (seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for the model
    # parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load mappings.
    start_time = time.time()
    print_rank_0(' > loading doc-idx mapping from {}'.format(
        doc_idx_filename))
    doc_idx = np.load(doc_idx_filename, allow_pickle=True)
    print_rank_0(' > loading sample-idx mapping from {}'.format(
        sample_idx_filename))
    sample_idx = np.load(sample_idx_filename, allow_pickle=True)
    print_rank_0(' > loading shuffle-idx mapping from {}'.format(
        shuffle_idx_filename))
    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True)
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        sample_idx.shape[0]))
    print_rank_0(' total number of epochs: {}'.format(num_epochs))

    return doc_idx, sample_idx, shuffle_idx


def _num_tokens(documents, sizes):
    """Total number of tokens in the dataset."""
    return np.sum(sizes[documents])


def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    """Based on the number of samples and sequence length, calculate how many
    epochs will be needed."""
    num_epochs = 0
    total_tokens = 0
    while True:
        num_epochs += 1
        total_tokens += tokens_per_epoch
        # -1 is because we need to retrieve seq_length + 1 tokens each time
        # but the last token will overlap with the first token of the next
        # sample except for the last sample.
        if ((total_tokens - 1) // seq_length) >= num_samples:
            return num_epochs


def _build_doc_idx(documents, num_epochs, np_rng):
    """Build an array with length = number-of-epochs * number-of-documents.
    Each index is mapped to a corresponding document."""
    doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
    doc_idx[:] = documents
    doc_idx = doc_idx.reshape(-1)
    doc_idx = doc_idx.astype(np.int32)
    np_rng.shuffle(doc_idx)
    return doc_idx


def _build_sample_idx(sizes, doc_idx, seq_length,
                      num_epochs, tokens_per_epoch):
    """Sample index mapping is a 2D array with sizes
    [number-of-samples + 1, 2] where [..., 0] contains
    the index into `doc_idx` and [..., 1] is the
    starting offset in that document."""

    # Total number of samples. For -1 see comments in `_num_epochs`.
    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)

    # Index into sample_idx.
    sample_index = 0
    # Index into doc_idx.
    doc_idx_index = 0
    # Beginning offset for each document.
    doc_offset = 0
    # Start with the first document and no offset.
    sample_idx[sample_index][0] = doc_idx_index
    sample_idx[sample_index][1] = doc_offset
    sample_index += 1
    while sample_index <= num_samples:
        # Start with a fresh sequence.
        remaining_seq_length = seq_length + 1
        while remaining_seq_length != 0:
            # Get the document length.
            doc_id = doc_idx[doc_idx_index]
            doc_length = sizes[doc_id] - doc_offset
            # And add it to the current sequence.
            remaining_seq_length -= doc_length
            # If we have more than a full sequence, adjust offset and set
            # remaining length to zero so we return from the while loop.
            # Note that -1 here is for the same reason we have -1 in
            # `_num_epochs` calculations.
            if remaining_seq_length <= 0:
                doc_offset += (remaining_seq_length + doc_length - 1)
                remaining_seq_length = 0
            else:
                # Otherwise, start from the beginning of the next document.
                doc_idx_index += 1
                doc_offset = 0
        # Record the sequence.
        sample_idx[sample_index][0] = doc_idx_index
        sample_idx[sample_index][1] = doc_offset
        sample_index += 1

    return sample_idx


def _build_shuffle_idx(size, np_rng):
    """Build the range [0, size) and shuffle."""
    dtype_ = np.uint32
    if size >= (np.iinfo(np.uint32).max - 1):
        dtype_ = np.int64
    shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
    np_rng.shuffle(shuffle_idx)
    return shuffle_idx
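Taken together, doc-idx, sample-idx, and shuffle-idx slice the flattened token stream into (num_epochs * tokens_per_epoch - 1) // seq_length samples, each seq_length + 1 tokens long so inputs and shifted labels come from a single slice. A rough driver sketch, not part of this merge request: the module path, data prefix, and all sizes below are assumptions, and the index build expects torch.distributed and Megatron's mpu to already be initialized (it all-reduces over the data-parallel group).

# Hedged usage sketch; 'my_corpus' and every numeric value here are placeholders.
from megatron.data.gpt2_dataset import build_train_valid_test_datasets

train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my_corpus',                       # expects my_corpus.bin/.idx
    data_impl='mmap',
    splits_string='949,50,1',                      # document split weights
    train_valid_test_num_samples=[1000, 100, 10],  # samples wanted per split
    seq_length=1024,
    seed=1234,
    skip_warmup=True)

sample = train_ds[0]['text']   # int64 token ids; length is seq_length + 1
assert sample.shape[0] == 1024 + 1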
@@ -33,6 +33,95 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for gpt2-like datasets for which
the documents are flattened and the samples are built based on this
1-D flattened array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and its length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
@@ -307,4 +396,5 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_sample_idx", &build_sample_idx);
}
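On the Python side, _build_index_mappings calls this extension as helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) with int32 inputs. A small sketch of what the returned array means, using made-up toy numbers and assuming the helpers extension has been compiled:

# Hedged sketch; the arrays and sizes are illustrative only.
import numpy as np
from megatron.data import helpers

sizes = np.array([5, 3, 7], dtype=np.int32)     # token count per document
doc_idx = np.array([0, 2, 1], dtype=np.int32)   # one shuffled epoch of documents
seq_length = 4
num_epochs = 1
tokens_per_epoch = int(sizes.sum())             # 15

sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                      num_epochs, tokens_per_epoch)
# Row i is (index into doc_idx, offset into that document) for the start of
# sample i; sample i spans [sample_idx[i], sample_idx[i+1]] inclusive.
print(sample_idx.shape)   # (num_samples + 1, 2), num_samples = (15 - 1) // 4 = 3
print(sample_idx[0])      # [0 0] -- first sample starts at doc_idx[0], offset 0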
@@ -42,6 +42,7 @@ def infer_dataset_impl(path):
else:
return None
else:
print(f"Dataset path does not exist: {path}")
return None
@@ -61,6 +62,7 @@ def make_dataset(path, impl, skip_warmup=False):
return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None
@@ -466,9 +468,8 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
if isinstance(idx, int):
ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=size, offset=ptr)
return np_array
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
@@ -478,10 +479,25 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes))
total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1])
return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property
def sizes(self):
return self._index.sizes
...
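The new get(idx, offset, length) method is what GPT2Dataset.__getitem__ relies on to read a slice of a document without materializing the whole item. A minimal usage sketch; the 'my_corpus' prefix is hypothetical and an mmap-style .bin/.idx pair is assumed to already exist at that prefix:

# Hedged sketch of the get() semantics shown in the diff above.
from megatron.data import indexed_dataset

ds = indexed_dataset.make_dataset('my_corpus', 'mmap', skip_warmup=True)
doc_len = ds.sizes[0]
whole_doc = ds.get(0)                    # all tokens of document 0
head = ds.get(0, length=10)              # first 10 tokens
tail = ds.get(0, offset=doc_len - 10)    # last 10 tokens
middle = ds.get(0, offset=2, length=8)   # tokens 2..9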
import argparse
import json
import multiprocessing
import nltk
import sys
import time
import torch
from bert_tokenization import FullTokenizer
import indexed_dataset
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
splitter = nltk.load("tokenizers/punkt/english.pickle")
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text=splitter._params,
lang_vars=CustomLanguageVars())
else:
Encoder.splitter = splitter
def encode(self, json_line):
text = json.loads(json_line)[self.args.json_key]
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
tokens = Encoder.tokenizer.tokenize(sentence)
ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
if len(ids) > 0:
doc_ids.append(ids)
return doc_ids, len(json_line)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='Path to input JSON')
parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
parser.add_argument('--json-key', type=str, default='text',
help='Key to extract from json')
parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix')
parser.add_argument('--workers', type=int, default=20,
help='Number of worker processes to launch')
parser.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
parser.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences.')
parser.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
args = parser.parse_args()
args.keep_empty = False
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
nltk.download("punkt", quiet=True)
encoder = Encoder(args)
tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25)
print(f"Vocab size: {tokenizer.vocab_size()}")
output_bin_file = "{}.bin".format(args.output_prefix)
output_idx_file = "{}.idx".format(args.output_prefix)
builder = indexed_dataset.make_builder(output_bin_file,
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size())
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for sentence in doc:
#print(sentence)
#print(tokenizer.convert_ids_to_tokens(sentence))
builder.add_item(torch.IntTensor(sentence))
builder.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
builder.finalize(output_idx_file)
if __name__ == '__main__':
main()
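The script writes <output-prefix>.bin and <output-prefix>.idx through the indexed_dataset builder, one item per sentence, with end_document() marking document boundaries. A minimal sketch of reading the result back; 'my_corpus' is a placeholder for whatever --output-prefix was used:

# Hedged sketch; assumes the script above was run with --output-prefix my_corpus
# and --dataset-impl mmap.
from megatron.data import indexed_dataset

ds = indexed_dataset.make_dataset('my_corpus', 'mmap')
print(len(ds))              # number of sentences (one builder item per sentence)
print(len(ds.doc_idx) - 1)  # number of documents, delimited by end_document()
print(ds[0])                # token ids of the first sentence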
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
import argparse
import os
import sys

@@ -7,52 +11,90 @@ import torch

script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))

from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset


def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    print(len(ds.doc_idx))
    print(len(ds))
    print(ds.doc_idx[-1])
    if ds.supports_prefetch:
        # just prefetch the whole thing in test (so assume it is small)
        ds.prefetch(range(len(ds)))
    if args.count > len(ds.doc_idx) - 1:
        args.count = len(ds.doc_idx) - 1

    for i in range(args.count):
        start = ds.doc_idx[i]
        end = ds.doc_idx[i + 1]
        ids = ds[start:end]
        print(f"Document {i}:")
        print("--------------")
        for s in ids:
            assert len(s) > 0
            l = s.data.tolist()
            text = tokenizer.detokenize(l)
            print(text)
            print("---")


def test_indexed_dataset_get(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    tokenizer = build_tokenizer(args)
    size = ds.sizes[0]
    print(f"size: {size}")
    full = ds.get(0)
    print(full)
    # print(tokenizer.detokenize(full.data.tolist()))
    print("---")
    end = ds.get(0, offset=size - 10)
    print(end)
    # print(tokenizer.detokenize(end.data.tolist()))

    start = ds.get(0, length=10)
    print(start)
    # print(tokenizer.detokenize(start.data.tolist()))

    part = ds.get(0, offset=2, length=8)
    print(part)
    # print(tokenizer.detokenize(part.data.tolist()))

# def test_albert_dataset(args):
#     # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
#     # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
#     # ds = AlbertDataset(idataset, tokenizer)
#     ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
#                                   args.epochs, args.max_num_samples,
#                                   args.masked_lm_prob, args.seq_length,
#                                   args.short_seq_prob, args.seed)
#     truncated = 0
#     total = 0
#     for i, s in enumerate(ds):
#         ids = s['text']
#         tokens = ds.tokenizer.convert_ids_to_tokens(ids)
#         print(tokens)
#         if i >= args.count-1:
#             exit()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--count', type=int, default=10,
                        help='Number of samples/documents to print')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=True,
                       choices=['BertWordPieceLowerCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')

    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
@@ -66,12 +108,15 @@ def main():
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    args = parser.parse_args()
    args.rank = 0
    args.make_vocab_size_divisible_by = 128
    args.model_parallel_size = 1

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    # test_albert_dataset(args)
    test_indexed_dataset_get(args)


if __name__ == "__main__":
    main()