Commit 690291dd authored by Mohammad Shoeybi

Fixed a couple of bugs in helpers

parent 1a1cfcff
@@ -12,6 +12,7 @@ from .dataset_utils import build_training_sample
 from . import helpers
 from megatron.data import FullBertTokenizer, indexed_dataset
+from megatron.utils import print_rank_0


 class AlbertDataset(Dataset):
@@ -31,11 +32,19 @@ class AlbertDataset(Dataset):
         # Build the samples mapping.
         if not num_epochs:
             if not max_num_samples:
-                raise ValueError("Need to specify either max_num_samples or num_epochs")
-            num_epochs = int(max_num_samples / len(indexed_dataset)) + 1
+                raise ValueError("Need to specify either max_num_samples "
+                                 "or num_epochs")
+            num_epochs = np.iinfo(np.int32).max - 1
         if not max_num_samples:
-            max_num_samples = len(indexed_dataset) * num_epochs
-        print(f"Building the sample map for {num_epochs} epochs or {max_num_samples} samples.")
+            max_num_samples = np.iinfo(np.int64).max - 1
+
+        # Make sure the types match the helpers input types.
+        assert indexed_dataset.doc_idx.dtype == np.int64
+        assert indexed_dataset.sizes.dtype == np.int32
+
+        # Build samples mapping.
+        verbose = torch.distributed.get_rank() == 0
+        start_time = time.time()
         self.samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
             indexed_dataset.sizes,
@@ -43,7 +52,14 @@ class AlbertDataset(Dataset):
             max_num_samples,
             self.max_seq_length-3, # account for added tokens
             short_seq_prob,
-            self.seed)
+            self.seed,
+            verbose)
+        # Make sure all the ranks have built the mapping.
+        torch.distributed.barrier()
+        print_rank_0('> elapsed time to build samples mapping (seconds): '
+                     '{:2f}'.format(time.time() - start_time))

         # Vocab stuff.
         self.vocab_id_list = list(tokenizer.inv_vocab.keys())
@@ -59,11 +75,12 @@ class AlbertDataset(Dataset):
                    num_epochs, max_num_samples, masked_lm_prob,
                    max_seq_length, short_seq_prob, seed, skip_warmup=False):
         tokenizer = FullBertTokenizer(vocab, do_lower_case=True)
-        print("> Reading dataset index")
-        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl, skip_warmup)
-        print("> Finished creating indexed dataset")
-        return cls(idx_ds, tokenizer, num_epochs, max_num_samples, masked_lm_prob,
-                   max_seq_length, short_seq_prob, seed)
+        print_rank_0("> Reading dataset index ...")
+        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl,
+                                              skip_warmup)
+        print_rank_0("> Finished creating indexed dataset")
+        return cls(idx_ds, tokenizer, num_epochs, max_num_samples,
+                   masked_lm_prob, max_seq_length, short_seq_prob, seed)

     def num_tokens(self):
         return self.tokenizer.vocab_size()
...
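For orientation, `helpers.build_mapping` returns a NumPy array with one row per training sample; each row holds a start sentence index, an end sentence index, and a target sequence length (see the C++ doc comment further down). A minimal sketch of how such a row could be consumed when assembling a sample is shown below; the function name and arguments are illustrative, not the repository's actual `__getitem__`.

```python
import numpy as np

def fetch_sample(samples_mapping, indexed_dataset, idx, seed=1234):
    """Gather the sentences for one training sample from a mapping row.

    `samples_mapping` has shape [num_samples, 3]; each row is
    (start_sentence_index, end_sentence_index, target_seq_length),
    where start/end index into the flat list of sentences.
    """
    start_idx, end_idx, target_seq_length = samples_mapping[idx]
    # Sentences [start_idx, end_idx) make up this sample; each entry is a
    # sequence of token ids from the underlying indexed dataset.
    sentences = [indexed_dataset[i] for i in range(start_idx, end_idx)]
    # A per-sample generator keeps masking reproducible across workers.
    rng = np.random.RandomState(seed + idx)
    return sentences, int(target_seq_length), rng
```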
+/* Helper methods for fast index mapping builds */
 #include <algorithm>
 #include <iostream>
@@ -6,46 +7,61 @@
 #include <stdexcept>
 #include <pybind11/pybind11.h>
 #include <pybind11/numpy.h>
+#include <random>

 namespace py = pybind11;
 using namespace std;

-inline uint32_t get_sample_len(const int short_seq_ratio,
-                               const uint32_t max_length) {
+inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
+                                     const int32_t max_length,
+                                     std::mt19937& rand32_gen) {
     /* Training sample length. */
-    const auto random_number = rand();
+    const auto random_number = rand32_gen();
     if ((random_number % short_seq_ratio) == 0) {
         return 2 + random_number % (max_length - 1);
     }
     return max_length;
 }
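The helper above first converts `short_seq_prob` into an integer ratio (`round(1 / p)`), then draws a 32-bit random number: roughly one sample in `short_seq_ratio` gets a random target length in `[2, max_length]`, and every other sample gets the full `max_length`. A Python rendering of the same logic, for illustration only:

```python
import random

def target_sample_len(short_seq_prob, max_length, rng=None):
    """Illustrative Python mirror of the C++ get_target_sample_len."""
    rng = rng or random.Random()
    # e.g. short_seq_prob = 0.1 -> roughly 1 in 10 samples is shortened.
    short_seq_ratio = round(1.0 / short_seq_prob)
    r = rng.randrange(2 ** 32)
    if r % short_seq_ratio == 0:
        # Short sample: target length anywhere in [2, max_length].
        return 2 + r % (max_length - 1)
    return max_length
```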
 template<typename DocIdx>
-py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
-                             const py::array_t<uint16_t>& sizes_,
-                             const int num_epochs,
+py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
+                             const py::array_t<int32_t>& sizes_,
+                             const int32_t num_epochs,
                              const uint64_t max_num_samples,
-                             const int max_seq_length,
+                             const int32_t max_seq_length,
                              const double short_seq_prob,
-                             const int seed) {
+                             const int32_t seed,
+                             const bool verbose) {
+    /* Build a mapping of (start-index, end-index, sequence-length) where
+       start and end index are the indices of the sentences in the sample
+       and sequence-length is the target sequence length.
+    */

-    cout << "> building dataset mapping for " << docs_.shape(0) - 1
-        << " documents with " << sizes_.shape(0) << " sentences ..."
-        << std::flush << endl;
+    if (verbose) {
+        cout << " > using " << docs_.shape(0) - 1 <<
+            " documents with " << sizes_.shape(0) << " sentences ..." <<
+            endl << std::flush;
+    }

+    // Consistency checks.
+    assert(num_epochs > 0);
+    assert(max_seq_length > 1);
+    assert(short_seq_prob > 0.0);
+    assert(short_seq_prob <= 1.0);
+    assert(seed > 0);

-    // For efficiency, convert probability to ratio.
-    const auto short_seq_ratio = static_cast<int>(round(1.0 / short_seq_prob));
+    // For efficiency, convert probability to ratio. Note: rand() generates int.
+    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));

     // Remove bound checks.
     auto docs = docs_.unchecked<1>();
     auto sizes = sizes_.unchecked<1>();

+    // Check for consistency.
     if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
         cout << "document values is not consistent with length of sizes: " <<
             docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
-        throw(-1);
+        throw std::length_error("docs and sizes");
     }

     // Mapping and its length (1D).
@@ -55,36 +71,39 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
     // Perform two iterations, in the first iteration get the size
     // and allocate memory and in the second iteration populate the map.
     bool second = false;
-    for (int iteration=0; iteration < 2; ++iteration) {
+    for (int32_t iteration=0; iteration<2; ++iteration) {

         // Set the seed so both iterations produce the same results.
-        srand(seed);
+        std::mt19937 rand32_gen(seed);

         // Set the flag on second iteration.
-        second = iteration == 1;
+        second = (iteration == 1);

         // Counters:
-        uint32_t empty_docs = 0;
-        uint32_t one_sent_docs = 0;
+        uint64_t empty_docs = 0;
+        uint64_t one_sent_docs = 0;

         // Current map index.
         uint64_t map_index = 0;

         // For each epoch:
-        for (int epoch=0; epoch < num_epochs; ++epoch) {
-            if (map_index >= max_num_samples && !second) {
-                cout << " > reached " << max_num_samples << " samples after "
-                    << epoch << " epochs ..." << std::flush << endl;
+        for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
+            if (map_index >= max_num_samples) {
+                if (verbose && (!second)) {
+                    cout << " > reached " << max_num_samples << " samples after "
+                        << epoch << " epochs ..." << endl << std::flush;
+                }
                 break;
             }
             // For each document:
-            for (int doc=0; doc < (docs.shape(0) - 1); ++doc) {
+            for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {

-                // Document sentences are in [sent_index_first, sent_index_last).
+                // Document sentences are in [sent_index_first, sent_index_last)
                 const auto sent_index_first = docs[doc];
                 const auto sent_index_last = docs[doc + 1];

-                // At the beginning of the document previous index is the start index.
+                // At the beginning of the document previous index is the
+                // start index.
                 auto prev_start_index = sent_index_first;

                 // Remaining documents.
@@ -93,13 +112,10 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                 // Some bookkeeping
                 if ((epoch == 0) && (!second)) {
                     if (num_remain_sent == 0) {
-                        cout << "***WARNING*** document " << doc << " is empty" << endl;
-                        empty_docs += 1;
+                        ++empty_docs;
                     }
                     if (num_remain_sent == 1) {
-                        // cout << "***WARNING*** document " << doc <<
-                        //     " has one sentence" << endl;
-                        one_sent_docs += 1;
+                        ++one_sent_docs;
                     }
                 }
@@ -107,110 +123,154 @@ py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                 if (num_remain_sent > 1) {

                     // Set values.
-                    auto size = uint32_t{0};
-                    auto num_sent = uint32_t{0};
-                    auto seq_len = get_sample_len(short_seq_ratio, max_seq_length);
+                    auto seq_len = int32_t{0};
+                    auto num_sent = int32_t{0};
+                    auto target_seq_len = get_target_sample_len(short_seq_ratio,
+                                                                max_seq_length,
+                                                                rand32_gen);

                     // Loop through sentences.
                     for (auto sent_index=sent_index_first;
                          sent_index < sent_index_last; ++sent_index) {

                         // Add the size and number of sentences.
-                        size += sizes[sent_index];
-                        num_sent += 1;
-                        num_remain_sent -= 1;
+                        seq_len += sizes[sent_index];
+                        ++num_sent;
+                        --num_remain_sent;

                         // If we have reached the target length.
                         // and if not only one sentence is left in the document.
                         // and if we have at least two sentences.
                         // and if we have reached end of the document.
-                        if (((size >= seq_len) && (num_remain_sent > 1) &&
-                            (num_sent > 1) ) || (num_remain_sent == 0)) {
+                        if (((seq_len >= target_seq_len) &&
+                             (num_remain_sent > 1) &&
+                             (num_sent > 1) ) || (num_remain_sent == 0)) {

-                            // Populate the map.
-                            if (second) {
-                                const auto map_index_0 = 3 * map_index;
-                                maps[map_index_0] = prev_start_index;
-                                maps[map_index_0 + 1] = sent_index + 1;
-                                maps[map_index_0 + 2] = seq_len;
-                            }
-
-                            // Update indices / counters.
-                            // check for overflow
-                            if (map_index == std::numeric_limits<DocIdx>::max()) {
-                                cout << "number of samples exceeded maximum allowed by type: "
-                                    << std::numeric_limits<DocIdx>::max() << endl;
-                                throw std::overflow_error("Number of samples");
-                            }
-                            map_index += 1;
-                            prev_start_index = sent_index + 1;
-                            seq_len = get_sample_len(short_seq_ratio, max_seq_length);
-                            size = 0;
-                            num_sent = 0;
-                        }
-                    }
+                            // Check for overflow.
+                            if ((3 * map_index + 2) >
+                                std::numeric_limits<int64_t>::max()) {
+                                cout << "number of samples exceeded maximum "
+                                     << "allowed by type int64: "
+                                     << std::numeric_limits<int64_t>::max()
+                                     << endl;
+                                throw std::overflow_error("Number of samples");
+                            }
+
+                            // Populate the map.
+                            if (second) {
+                                const auto map_index_0 = 3 * map_index;
+                                maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
+                                maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
+                                maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
+                            }
+
+                            // Update indices / counters.
+                            ++map_index;
+                            prev_start_index = sent_index + 1;
+                            target_seq_len = get_target_sample_len(short_seq_ratio,
+                                                                   max_seq_length,
+                                                                   rand32_gen);
+                            seq_len = 0;
+                            num_sent = 0;
+                        }
+                    } // for (auto sent_index=sent_index_first; ...
                 } // if (num_remain_sent > 1) {
             } // for (int doc=0; doc < num_docs; ++doc) {
         } // for (int epoch=0; epoch < num_epochs; ++epoch) {

         if (!second) {
-            cout << " number of samples: " <<
-                map_index << endl;
-            cout << " number of empty documents: " <<
-                empty_docs << endl;
-            cout << " number of documents with one sentence: " <<
-                one_sent_docs << endl;
+            if (verbose) {
+                cout << " > number of empty documents: " << empty_docs <<
+                    endl << std::flush;
+                cout << " > number of documents with one sentence: " <<
+                    one_sent_docs << endl << std::flush;
+                cout << " > will create mapping for " << map_index <<
+                    " samples" << endl << std::flush;
+            }
+            assert(maps == NULL);
+            assert(num_samples < 0);
             maps = new DocIdx[3*map_index];
-            num_samples = map_index;
+            num_samples = static_cast<int64_t>(map_index);
         }
     } // for (int iteration=0; iteration < 2; ++iteration) {
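The two-iteration structure above is a count-then-fill pattern: the first pass only counts how many samples will be produced so that exactly `3 * map_index` entries can be allocated, and the second pass writes them; reseeding the generator at the top of each pass guarantees both passes make identical length decisions and therefore identical sample boundaries. A generic Python sketch of the idea (names are illustrative):

```python
import numpy as np

def two_pass_build(traverse, seed):
    """Count-then-fill: run the same deterministic traversal twice.

    `traverse(rng, out)` must visit the data identically on both passes;
    it writes rows into `out` when `out` is not None and always returns
    the number of rows it produced.
    """
    maps = None
    for second in (False, True):
        rng = np.random.RandomState(seed)  # same seed -> same decisions
        num_rows = traverse(rng, maps if second else None)
        if not second:
            maps = np.empty((num_rows, 3), dtype=np.int64)
    return maps
```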
     // Shuffle.
+    // We need a 64 bit random number generator as we might have more
+    // than 2 billion samples.
+    std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
-        const auto j = rand() % (i + 1);
+        const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
         const auto i0 = 3 * i;
         const auto j0 = 3 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);
         swap(maps[i0 + 2], maps[j0 + 2]);
     }

-    cout << " > done building the mapping." << endl;
+    if (verbose) {
+        cout << "> done building the mapping." << endl;
+    }
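The shuffle above is an in-place Fisher-Yates pass over rows of the map, and it now draws `j` from a 64-bit generator seeded with `seed + 1` because the sample count can exceed what a 32-bit value can index. An equivalent sketch in Python, for illustration only (Python integers are unbounded, so the width concern disappears):

```python
import random
import numpy as np

def shuffle_rows(maps, seed):
    """In-place Fisher-Yates shuffle of the rows of an [N, 3] NumPy array."""
    rng = random.Random(seed + 1)
    for i in range(len(maps) - 1, 0, -1):
        j = rng.randrange(i + 1)
        maps[[i, j]] = maps[[j, i]]  # swap rows i and j
    return maps
```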
     // Method to deallocate memory.
     py::capsule free_when_done(maps, [](void *mem_) {
         DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
-        cout << "freeing memory for the dataset mapping" << endl;
         delete[] mem;
     });

     // Return the numpy array.
+    const auto byte_size = sizeof(DocIdx);
     return py::array(std::vector<int64_t>{num_samples, 3}, // shape
-                     {3*4, 4}, // C-style contiguous strides
+                     {3*byte_size, byte_size}, // C-style contiguous strides
                      maps, // the data pointer
                      free_when_done); // numpy array references
 }
-py::array build_mapping(const py::array& docs_,
-                        const py::array& sizes_,
+py::array build_mapping(const py::array_t<int64_t>& docs_,
+                        const py::array_t<int>& sizes_,
                         const int num_epochs,
                         const uint64_t max_num_samples,
                         const int max_seq_length,
                         const double short_seq_prob,
-                        const int seed) {
+                        const int seed,
+                        const bool verbose) {

+    if (verbose) {
+        cout << "> building sample map using: " << endl << std::flush;
+        cout << "    number of epochs:           " << num_epochs << endl
+             << std::flush;
+        cout << "    maximum number of samples:  " << max_num_samples << endl
+             << std::flush;
+        cout << "    maximum sequence length:    " << max_seq_length << endl
+             << std::flush;
+        cout << "    short sequence probability: " << short_seq_prob << endl
+             << std::flush;
+        cout << "    seed:                       " << seed << endl
+             << std::flush;
+    }

     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
-        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, max_num_samples,
-                                            max_seq_length, short_seq_prob, seed);
+        if (verbose) {
+            cout << " > using uint64 for data mapping..." << endl << std::flush;
+        }
+        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
+                                            max_num_samples, max_seq_length,
+                                            short_seq_prob, seed, verbose);
     } else {
-        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, max_num_samples,
-                                            max_seq_length, short_seq_prob, seed);
+        if (verbose) {
+            cout << " > using uint32 for data mapping..." << endl << std::flush;
+        }
+        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
+                                            max_num_samples, max_seq_length,
+                                            short_seq_prob, seed, verbose);
     }
 }

 PYBIND11_MODULE(helpers, m) {
     m.def("build_mapping", &build_mapping);
 }
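Putting the pieces together, the exposed `build_mapping` entry point expects `doc_idx` as int64 and `sizes` as int32 (matching the asserts added on the Python side), followed by the epoch count, sample cap, maximum sequence length, short-sequence probability, seed, and the new `verbose` flag. A hedged usage sketch, assuming the extension has been compiled and is importable as `megatron.data.helpers`:

```python
import numpy as np
from megatron.data import helpers  # compiled pybind11 extension

# Toy inputs: 3 documents whose sentence lengths (in tokens) are in `sizes`.
# doc_idx[d] is the index of document d's first sentence; the final entry
# equals len(sizes), which the C++ consistency check requires.
doc_idx = np.array([0, 2, 5, 9], dtype=np.int64)
sizes = np.array([12, 7, 30, 22, 18, 9, 40, 11, 25], dtype=np.int32)

samples_mapping = helpers.build_mapping(
    doc_idx,
    sizes,
    3,         # num_epochs
    1000,      # max_num_samples
    128 - 3,   # max_seq_length, minus the special tokens added later
    0.1,       # short_seq_prob
    1234,      # seed
    True)      # verbose

# One row per sample: (start sentence index, end sentence index, target length).
print(samples_mapping.shape, samples_mapping.dtype)
```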
@@ -18,6 +18,7 @@ from itertools import accumulate

 import numpy as np
 import torch
+from megatron.utils import print_rank_0


 def __best_fitting_dtype(vocab_size=None):
     if vocab_size is not None and vocab_size < 65500:
@@ -317,7 +318,7 @@ class IndexedDatasetBuilder(object):

 def _warmup_mmap_file(path):
     with open(path, 'rb') as stream:
-        while stream.read(1 * 1024 * 1024):
+        while stream.read(100 * 1024 * 1024):
             pass
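The warm-up change simply reads the file sequentially in 100 MB chunks (up from 1 MB) so far fewer Python-level `read` calls are needed to pull the whole file through the OS page cache before it is memory-mapped. A small usage sketch (the file name is hypothetical):

```python
import numpy as np

def warm_then_map(path, chunk_bytes=100 * 1024 * 1024):
    # Touch every byte sequentially so the page cache is hot before the
    # random accesses that go through the memmap below.
    with open(path, 'rb') as stream:
        while stream.read(chunk_bytes):
            pass
    return np.memmap(path, dtype=np.uint8, mode='r', order='C')

# buffer = warm_then_map('my_dataset.bin')  # hypothetical path
```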
@@ -390,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                 offset = stream.tell()

             if not skip_warmup:
-                print("> Warming up index mmap file...")
+                print_rank_0("> Warming up index mmap file...")
                 _warmup_mmap_file(path)

             self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
             self._bin_buffer = memoryview(self._bin_buffer_mmap)
-            print("> Reading sizes...")
+            print_rank_0("> Reading sizes...")
             self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
-            print("> Reading pointers...")
+            print_rank_0("> Reading pointers...")
             self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                            offset=offset + self._sizes.nbytes)
-            print("> Reading document index...")
+            print_rank_0("> Reading document index...")
             self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                           offset=offset + self._sizes.nbytes + self._pointers.nbytes)

         def __del__(self):
@@ -446,13 +447,13 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
         self._index = self.Index(index_file_path(self._path), skip_warmup)

         if not skip_warmup:
-            print("> Warming up data mmap file...")
+            print_rank_0("> Warming up data mmap file...")
             _warmup_mmap_file(data_file_path(self._path))
-        print("> Creating numpy buffer of mmap...")
+        print_rank_0("> Creating numpy buffer of mmap...")
         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
-        print("> Creating memory view of numpy buffer...")
+        print_rank_0("> Creating memory view of numpy buffer...")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
-        print("> Done")
+        print_rank_0("> Done")

     def __del__(self):
         self._bin_buffer_mmap._mmap.close()
...
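`print_rank_0` is imported from `megatron.utils` but not defined in this diff; as used here it is expected to print only on global rank 0 so that multi-GPU runs do not repeat every message once per process. A minimal sketch of that behavior, an assumption rather than the repository's actual implementation:

```python
import torch

def print_rank_0(message):
    """Print only on rank 0 (or always, if torch.distributed is not set up).

    Assumed behavior of megatron.utils.print_rank_0, inferred from how it
    is used in this commit.
    """
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)
```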