Commit 0601702a authored by Mohammad Shoeybi

zero worker seems to be working

parent f6a6811f
@@ -8,6 +8,7 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
+from megatron import mpu
 from megatron.data import helpers
 from megatron.data import FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
@@ -15,22 +16,97 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.utils import print_rank_0
+def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
+                                    splits_string, train_valid_test_num_samples,
+                                    max_seq_length, masked_lm_prob,
+                                    short_seq_prob, seed, skip_warmup):
+    # Tokenizer is the same
+    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
+    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
+        tokenizer.vocab_size()))
+    # Indexed dataset.
+    indexed_dataset = get_indexed_dataset_(data_prefix,
+                                           data_impl,
+                                           skip_warmup)
+    # Get start and end indices of train/valid/test into doc-idx
+    # Note that doc-idx is designed to be num-docs + 1 so we can
+    # easily iterate over it.
+    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
+    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
+    # Print stats about the splits.
+    print_rank_0(' > dataset split:')
+    def print_split_stats(name, index):
+        print_rank_0(' {}:'.format(name))
+        print_rank_0(' document indices in [{}, {}) total of {} '
+                     'documents'.format(splits[index], splits[index + 1],
+                                        splits[index + 1] - splits[index]))
+        start_index = indexed_dataset.doc_idx[splits[index]]
+        end_index = indexed_dataset.doc_idx[splits[index + 1]]
+        print_rank_0(' sentence indices in [{}, {}) total of {} '
+                     'sentences'.format(start_index, end_index,
+                                        end_index - start_index))
+    print_split_stats('train', 0)
+    print_split_stats('validation', 1)
+    print_split_stats('test', 2)
+    def build_dataset(index, name):
+        dataset = None
+        if splits[index + 1] > splits[index]:
+            # Get the pointer to the original doc-idx so we can set it later.
+            doc_idx_ptr = indexed_dataset.get_doc_idx()
+            # Slice the doc-idx
+            start_index = splits[index]
+            # Add +1 so we can index into the dataset to get the upper bound.
+            end_index = splits[index + 1] + 1
+            # New doc_idx view.
+            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
+            # Build the dataset accordingly.
+            dataset = AlbertDataset(
+                name=name,
+                indexed_dataset=indexed_dataset,
+                tokenizer=tokenizer,
+                data_prefix=data_prefix,
+                num_epochs=None,
+                max_num_samples=train_valid_test_num_samples[index],
+                masked_lm_prob=masked_lm_prob,
+                max_seq_length=max_seq_length,
+                short_seq_prob=short_seq_prob,
+                seed=seed)
+            # Set the original pointer so dataset remains the main dataset.
+            indexed_dataset.set_doc_idx(doc_idx_ptr)
+            # Checks.
+            assert indexed_dataset.doc_idx[0] == 0
+            assert indexed_dataset.doc_idx.shape[0] == \
+                (total_num_of_documents + 1)
+        return dataset
+    train_dataset = build_dataset(0, 'train')
+    valid_dataset = build_dataset(1, 'valid')
+    test_dataset = build_dataset(2, 'test')
+    return (train_dataset, valid_dataset, test_dataset)
 class AlbertDataset(Dataset):
-    def __init__(self, vocab_file, data_prefix, data_impl, skip_warmup,
-                 num_epochs, max_num_samples, masked_lm_prob, max_seq_length,
-                 short_seq_prob, seed):
+    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed):
         # Params to store.
+        self.name = name
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
-        self.tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-        # Indexed dataset.
-        self.indexed_dataset = get_indexed_dataset_(data_prefix,
-                                                    data_impl,
-                                                    skip_warmup)
+        # Tokenizer and dataset.
+        self.tokenizer = tokenizer
+        self.indexed_dataset = indexed_dataset
         # Build the samples mapping.
         self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
@@ -39,7 +115,8 @@ class AlbertDataset(Dataset):
                                                     max_num_samples,
                                                     self.max_seq_length,
                                                     short_seq_prob,
-                                                    self.seed)
+                                                    self.seed,
+                                                    self.name)
         # Vocab stuff.
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
@@ -48,7 +125,6 @@ class AlbertDataset(Dataset):
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.mask_id = self.tokenizer.vocab['[MASK]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
-        exit()
     def num_tokens(self):
@@ -68,9 +144,11 @@ class AlbertDataset(Dataset):
         sample = []
         for index in range(start_index, end_index):
             sample.append(self.indexed_dataset[index])
+        '''
         for s in sample:
             if len(s) > 1000:
                 print(self.tokenizer.convert_ids_to_tokens(s))
+        '''
         return build_training_sample(sample, seq_length,
                                      self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
@@ -80,25 +158,63 @@ class AlbertDataset(Dataset):
                                      self.masked_lm_prob, rng)
 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+    print_rank_0(' > building dataset index ...')
     start_time = time.time()
-    print_rank_0("> Reading dataset index ...")
     indexed_dataset = make_indexed_dataset(data_prefix,
                                            data_impl,
                                            skip_warmup)
-    print_rank_0("> Finished creating indexed dataset in {:4f} "
-                 "seconds".format(time.time() - start_time))
+    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
+    print_rank_0(' > finished creating indexed dataset in {:4f} '
+                 'seconds'.format(time.time() - start_time))
+    print_rank_0(' > indexed dataset stats:')
+    print_rank_0(' number of documents: {}'.format(
+        indexed_dataset.doc_idx.shape[0] - 1))
+    print_rank_0(' number of sentences: {}'.format(
+        indexed_dataset.sizes.shape[0]))
     return indexed_dataset
+def get_train_valid_test_split_(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
+    splits = []
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
+    else:
+        splits = [float(splits_string)]
+    while len(splits) < 3:
+        splits.append(0.)
+    splits = splits[:3]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split/splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    assert len(splits_index) == 4
+    assert splits_index[-1] == size
+    return splits_index
 def get_samples_mapping_(indexed_dataset,
                          data_prefix,
                          num_epochs,
                          max_num_samples,
                          max_seq_length,
                          short_seq_prob,
-                         seed):
+                         seed,
+                         name):
     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "
@@ -109,9 +225,11 @@ def get_samples_mapping_(indexed_dataset,
     # Filename of the index mapping
     indexmap_filename = data_prefix
-    indexmap_filename += '_indexmap'
-    indexmap_filename += '_{}ep'.format(num_epochs)
-    indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
     indexmap_filename += '_{}msl'.format(max_seq_length)
     indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
     indexmap_filename += '_{}s'.format(seed)
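As a sanity check on the naming scheme above, here is a small self-contained sketch of the filename it produces. The prefix 'my-corpus' and the numeric settings are made up; num_epochs is left at its "effectively unbounded" sentinel, so the '_{}ep' piece is skipped.

```python
# Hypothetical illustration of the index-map cache filename built above.
import numpy as np

data_prefix = 'my-corpus'                     # made-up prefix
name = 'train'
num_epochs = np.iinfo(np.int32).max - 1       # sentinel: omitted from the name
max_num_samples = 1000000
max_seq_length, short_seq_prob, seed = 512, 0.1, 1234

indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
    indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
    indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
indexmap_filename += '_{}s'.format(seed)
print(indexmap_filename)
# my-corpus_train_indexmap_1000000mns_512msl_0.10ssp_1234s
```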
@@ -120,8 +238,9 @@ def get_samples_mapping_(indexed_dataset,
     # Build the indexed mapping if not exist.
     if torch.distributed.get_rank() == 0 and \
        not os.path.isfile(indexmap_filename):
-        print('WARNING: could not find index map file {}, building '
+        print(' > WARNING: could not find index map file {}, building '
               'the indices on rank 0 ...'.format(indexmap_filename))
         # Make sure the types match the helpers input types.
         assert indexed_dataset.doc_idx.dtype == np.int64
         assert indexed_dataset.sizes.dtype == np.int32
@@ -129,6 +248,8 @@ def get_samples_mapping_(indexed_dataset,
         # Build samples mapping
         verbose = torch.distributed.get_rank() == 0
         start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
         samples_mapping = helpers.build_mapping(
             indexed_dataset.doc_idx,
             indexed_dataset.sizes,
@@ -138,21 +259,30 @@ def get_samples_mapping_(indexed_dataset,
             short_seq_prob,
             seed,
             verbose)
+        print_rank_0(' > done building samples index mapping')
         np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
         # Make sure all the ranks have built the mapping
-        print_rank_0('> elapsed time to build and save samples mapping '
+        print_rank_0(' > elapsed time to build and save samples mapping '
                      '(seconds): {:4f}'.format(
                          time.time() - start_time))
-    torch.distributed.barrier()
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())
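The all-reduce above stands in for a barrier because an NCCL barrier would assume device_index == rank, which does not hold under model parallelism. A minimal standalone sketch of the same idiom, assuming torch.distributed is already initialized with a CUDA/NCCL backend and `group` is whatever process group the caller wants to synchronize:

```python
# Sketch: use a tiny all-reduce as a barrier substitute for a process group.
import torch
import torch.distributed as dist

def barrier_via_allreduce(group=None):
    # Every rank contributes a 1; the call only returns once all ranks
    # in the group have reached this point.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    # The summed value equals the group size, which doubles as a check.
    assert counts[0].item() == dist.get_world_size(group=group)
```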
     # Load indexed dataset.
-    print_rank_0('> loading indexed mapping from {}'.format(
+    print_rank_0(' > loading indexed mapping from {}'.format(
         indexmap_filename))
     start_time = time.time()
     samples_mapping = np.load(indexmap_filename, allow_pickle=True)
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
         samples_mapping.shape[0]))
     return samples_mapping
...
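For reference, a worked example (with made-up numbers) of how get_train_valid_test_split_ above turns a split string and a document count into index boundaries; the rounding correction is pushed onto the first split so the last boundary always equals the total size.

```python
# Worked example of the split logic defined above (numbers are hypothetical).
splits_string, size = '949,50,1', 10000

splits = [float(s) for s in splits_string.split(',')]
splits = [split / sum(splits) for split in splits]   # [0.949, 0.05, 0.001]
splits_index = [0]
for index, split in enumerate(splits):
    splits_index.append(splits_index[index] + int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
    splits_index[index] -= diff                      # absorb rounding error

print(splits_index)  # [0, 9490, 9990, 10000]
```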
@@ -39,12 +39,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
        and sequence-length is the target sequence length.
     */
-    if (verbose) {
-        cout << " > using " << docs_.shape(0) - 1 <<
-            " documents with " << sizes_.shape(0) << " sentences ..." <<
-            endl << std::flush;
-    }
     // Consistency checks.
     assert(num_epochs > 0);
     assert(max_seq_length > 1);
@@ -52,16 +46,36 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     assert(short_seq_prob <= 1.0);
     assert(seed > 0);
-    // For efficiency, convert probability to ratio. Note: rand() generates int.
-    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
     // Remove bound checks.
     auto docs = docs_.unchecked<1>();
     auto sizes = sizes_.unchecked<1>();
-    if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
-        cout << "document values is not consistent with length of sizes: " <<
-            docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
-        throw std::length_error("docs and sizes");
+    // For efficiency, convert probability to ratio. Note: rand() generates int.
+    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+    if (verbose) {
+        const auto sent_start_index = docs[0];
+        const auto sent_end_index = docs[docs_.shape(0) - 1];
+        const auto num_sentences = sent_end_index - sent_start_index;
+        cout << " using:" << endl << std::flush;
+        cout << " number of documents: " << docs_.shape(0) - 1 <<
+            endl << std::flush;
+        cout << " sentences range: [" << sent_start_index <<
+            ", " << sent_end_index << ")" << endl << std::flush;
+        cout << " total number of sentences: " << num_sentences <<
+            endl << std::flush;
+        cout << " number of epochs: " << num_epochs <<
+            endl << std::flush;
+        cout << " maximum number of samples: " << max_num_samples <<
+            endl << std::flush;
+        cout << " maximum sequence length: " << max_seq_length <<
+            endl << std::flush;
+        cout << " short sequence probability: " << short_seq_prob <<
+            endl << std::flush;
+        cout << " short sequence ratio (1/prob): " << short_seq_ratio <<
+            endl << std::flush;
+        cout << " seed: " << seed << endl <<
+            std::flush;
     }
     // Mapping and its length (1D).
@@ -90,7 +104,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
         if (map_index >= max_num_samples) {
             if (verbose && (!second)) {
-                cout << " > reached " << max_num_samples << " samples after "
+                cout << " reached " << max_num_samples << " samples after "
                      << epoch << " epochs ..." << endl << std::flush;
             }
             break;
@@ -181,11 +195,11 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     if (!second) {
         if (verbose) {
-            cout << " > number of empty documents: " << empty_docs <<
+            cout << " number of empty documents: " << empty_docs <<
                 endl << std::flush;
-            cout << " > number of documents with one sentence: " <<
+            cout << " number of documents with one sentence: " <<
                 one_sent_docs << endl << std::flush;
-            cout << " > will create mapping for " << map_index <<
+            cout << " will create mapping for " << map_index <<
                 " samples" << endl << std::flush;
         }
         assert(maps == NULL);
@@ -210,10 +224,6 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
             swap(maps[i0 + 2], maps[j0 + 2]);
         }
-        if (verbose) {
-            cout << "> done building the mapping." << endl;
-        }
     // Method to deallocate memory.
     py::capsule free_when_done(maps, [](void *mem_) {
         DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
@@ -239,34 +249,20 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                         const int seed,
                         const bool verbose) {
-    if (verbose) {
-        cout << "> building sample map using: " << endl << std::flush;
-        cout << " number of epochs: " << num_epochs << endl
-            << std::flush;
-        cout << " maximum number of samples: " << max_num_samples << endl
-            << std::flush;
-        cout << " maximum sequence length: " << max_seq_length << endl
-            << std::flush;
-        cout << " short sequence probability: " << short_seq_prob << endl
-            << std::flush;
-        cout << " seed: " << seed << endl
-            << std::flush;
-    }
     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
         if (verbose) {
-            cout << " > using uint64 for data mapping..." << endl << std::flush;
+            cout << " using uint64 for data mapping..." << endl << std::flush;
         }
         return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
                                             short_seq_prob, seed, verbose);
     } else {
         if (verbose) {
-            cout << " > using uint32 for data mapping..." << endl << std::flush;
+            cout << " using uint32 for data mapping..." << endl << std::flush;
         }
         return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
                                             short_seq_prob, seed, verbose);
     }
 }
...
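The C++ hunk above only converts short_seq_prob into an integer ratio for use with rand(); how that ratio is consumed is outside this diff, so the modulo test in the sketch below is an assumption about that usage rather than a copy of it.

```python
# Sketch of the probability-to-ratio trick noted in the C++ comment above.
# short_seq_prob = 0.1 becomes short_seq_ratio = round(1 / 0.1) = 10, so an
# integer RNG can emulate the probability; the modulo check is an assumption.
import random

short_seq_prob = 0.1
short_seq_ratio = int(round(1.0 / short_seq_prob))   # 10

trials = 100000
hits = sum(1 for _ in range(trials)
           if random.randrange(1 << 30) % short_seq_ratio == 0)
print(short_seq_ratio, hits / trials)  # roughly 0.1 of draws come out "short"
```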
@@ -391,17 +391,17 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
                offset = stream.tell()
            if not skip_warmup:
-                print_rank_0("> Warming up index mmap file...")
+                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)
            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
-            print_rank_0("> Reading sizes...")
+            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
-            print_rank_0("> Reading pointers...")
+            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
-            print_rank_0("> Reading document index...")
+            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)
        def __del__(self):
@@ -447,13 +447,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
        self._index = self.Index(index_file_path(self._path), skip_warmup)
        if not skip_warmup:
-            print_rank_0("> Warming up data mmap file...")
+            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
-        print_rank_0("> Creating numpy buffer of mmap...")
+        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
-        print_rank_0("> Creating memory view of numpy buffer...")
+        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)
-        print_rank_0("> Done")
    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
@@ -470,7 +469,6 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
            if self._index.dtype != np.int64:
                np_array = np_array.astype(np.int64)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
@@ -492,6 +490,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    def doc_idx(self):
        return self._index.doc_idx
+    def get_doc_idx(self):
+        return self._index._doc_idx
+    def set_doc_idx(self, doc_idx_):
+        self._index._doc_idx = doc_idx_
    @property
    def supports_prefetch(self):
        return False
...
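The new get_doc_idx/set_doc_idx accessors exist so that build_dataset in albert_dataset.py can swap in a sliced doc_idx view per split. A small sketch (with made-up numbers) of why that slice carries the extra +1 entry: doc_idx has num_docs + 1 entries, so a document range needs one trailing boundary for its upper sentence bound.

```python
# Sketch of the doc_idx fencepost layout used by the split views above.
import numpy as np

doc_idx = np.array([0, 3, 7, 12, 20])   # 4 documents, 20 sentences (made up)
splits = [0, 2, 3, 4]                   # train docs [0,2), valid [2,3), test [3,4)

train_view = doc_idx[splits[0]:splits[1] + 1]   # [0, 3, 7]  -> sentences [0, 7)
valid_view = doc_idx[splits[1]:splits[2] + 1]   # [7, 12]    -> sentences [7, 12)
test_view = doc_idx[splits[2]:splits[3] + 1]    # [12, 20]   -> sentences [12, 20)
print(train_view, valid_view, test_view)
```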
@@ -13,43 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""dataset to split one large dataset into multiple smaller datasets"""
import torch
import numpy as np
-def should_split(split):
-    """
-    given split proportions checks if should split
-    Examples:
-    >>> should_split([10,0,0])
-    False
-    >>> should_split([1,.1,.2])
-    True
-    """
-    return max(split)/sum(split) != 1.
-def get_split(args):
-    """
-    Get dataset splits from comma separated string list
-    """
+def get_train_valid_test_split(splits_string, size):
+    """ Get dataset splits from comma or '/' separated string list."""
    splits = []
-    if args.split.find(',') != -1:
-        splits = [float(s) for s in args.split.split(',')]
-    elif args.split.find('/') != -1:
-        splits = [float(s) for s in args.split.split('/')]
+    if splits_string.find(',') != -1:
+        splits = [float(s) for s in splits_string.split(',')]
+    elif splits_string.find('/') != -1:
+        splits = [float(s) for s in splits_string.split('/')]
    else:
-        splits = [float(args.split)]
-    split_total = sum(splits)
-    if split_total < 1.:
-        splits.append(1-split_total)
+        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
-    if args.valid_data is not None:
-        splits[1] = 0.
-    if args.test_data is not None:
-        splits[2] = 0.
-    final_sum = sum(splits)
-    return [s/final_sum for s in splits]
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits = [split/splits_sum for split in splits]
+    splits_index = [0]
+    for index, split in enumerate(splits):
+        splits_index.append(splits_index[index] +
+                            int(round(split * float(size))))
+    diff = splits_index[-1] - size
+    for index in range(1, len(splits_index)):
+        splits_index[index] -= diff
+    return splits_index
class SplitDataset(torch.utils.data.Dataset):
    """
...
@@ -13,21 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Pretrain BERT"""
+"""Pretrain ALBERT"""
import torch
import torch.nn.functional as F
-from configure_data import configure_data
from megatron import mpu
from megatron.model import BertModel
from megatron.utils import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
-from megatron.data import AlbertDataset, split_dataset
+from megatron.data.albert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args):
    """Build the model."""
@@ -109,94 +109,98 @@ def forward_step(data_iterator, model, args, timers):
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
-        if args.data_loader == None:
+        print_rank_0('> building train, validation, and test datasets '
+                     'for ALBERT ...')
+        if args.data_loader is None:
            args.data_loader = 'binary'
-        if args.data_loader == 'binary':
-            if not args.max_num_samples:
-                args.max_num_samples = (args.train_iters + 2 * args.eval_iters) * args.batch_size
-            if not args.data_path:
-                print("Albert currently only supports a unified dataset specified with --data-path")
-                exit(1)
-            print_rank_0("Creating AlbertDataset...")
-            full_data = AlbertDataset(
-                vocab_file=args.vocab,
-                data_prefix=args.data_path,
-                data_impl=args.data_impl,
-                skip_warmup=args.skip_mmap_warmup,
-                num_epochs=args.data_epochs,
-                max_num_samples=args.max_num_samples,
-                masked_lm_prob=args.mask_prob,
-                max_seq_length=args.seq_length,
-                short_seq_prob=args.short_seq_prob,
-                seed=args.seed)
-            print_rank_0("Finished creating AlbertDataset...")
-            split = split_dataset.get_split(args)
-            if split_dataset.should_split(split):
-                train_ds, val_ds, test_ds = split_dataset.split_ds(full_data, split, args.shuffle)
-            else:
-                train_ds = full_data
-            num_tokens = train_ds.num_tokens()
-            world_size = mpu.get_data_parallel_world_size()
-            rank = mpu.get_data_parallel_rank()
-            global_batch_size = args.batch_size * world_size
-            num_workers = args.num_workers
-            def make_data_loader_(dataset):
-                if not dataset:
-                    return None
-                # Use a simple sampler with distributed batch sampler.
-                sampler = torch.utils.data.SequentialSampler(dataset)
-                batch_sampler = DistributedBatchSampler(
-                    sampler=sampler,
-                    batch_size=global_batch_size,
-                    drop_last=True,
-                    rank=rank,
-                    world_size=world_size)
-                # Torch dataloader.
-                return torch.utils.data.DataLoader(dataset,
-                                                   batch_sampler=batch_sampler,
-                                                   num_workers=num_workers,
-                                                   pin_memory=True)
-            train_data = make_data_loader_(train_ds)
-            valid_data = make_data_loader_(val_ds)
-            test_data = make_data_loader_(test_ds)
-            do_train = train_data is not None and args.train_iters > 0
-            do_valid = valid_data is not None and args.eval_iters > 0
-            do_test = test_data is not None and args.eval_iters > 0
-            # Need to broadcast num_tokens and num_type_tokens.
-            token_counts = torch.cuda.LongTensor([num_tokens,
-                                                  2, # hard coded num_type_tokens for now
-                                                  int(do_train),
-                                                  int(do_valid),
-                                                  int(do_test)])
-        else:
-            print("Unsupported data loader for BERT.")
-            exit(1)
+        if args.data_loader != 'binary':
+            print('Unsupported {} data loader for ALBERT.'.format(
+                args.data_loader))
+            exit(1)
+        if not args.data_path:
+            print('ALBERT only supports a unified dataset specified '
+                  'with --data-path')
+            exit(1)
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [args.train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
+        print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            vocab_file=args.vocab,
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            max_seq_length=args.seq_length,
+            masked_lm_prob=args.mask_prob,
+            short_seq_prob=args.short_seq_prob,
+            seed=args.seed,
+            skip_warmup=args.skip_mmap_warmup)
+        print_rank_0("> finished creating ALBERT datasets ...")
+        def make_data_loader_(dataset):
+            if not dataset:
+                return None
+            # Use a simple sampler with distributed batch sampler.
+            sampler = torch.utils.data.SequentialSampler(dataset)
+            batch_sampler = DistributedBatchSampler(
+                sampler=sampler,
+                batch_size=global_batch_size,
+                drop_last=True,
+                rank=data_parallel_rank,
+                world_size=data_parallel_size)
+            # Torch dataloader.
+            return torch.utils.data.DataLoader(dataset,
+                                               batch_sampler=batch_sampler,
+                                               num_workers=args.num_workers,
+                                               pin_memory=True)
+        train_data = make_data_loader_(train_ds)
+        valid_data = make_data_loader_(valid_ds)
+        test_data = make_data_loader_(test_ds)
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        num_tokens = vocab_size_with_padding(train_ds.num_tokens(), args)
+        token_counts = torch.cuda.LongTensor([num_tokens,
+                                              2, # hard coded num_type_tokens
+                                              int(do_train),
+                                              int(do_valid),
+                                              int(do_test)])
    else:
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])
    # Broadcast num tokens.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    num_type_tokens = token_counts[1].item()
+    args.vocab_size = token_counts[0].item()
+    args.tokentype_size = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()
-    args.vocab_size = num_tokens
-    args.tokentype_size = num_type_tokens
-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data
if __name__ == "__main__":
...
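A worked example (with hypothetical settings) of the dataset target sizes computed in the new get_train_val_test_data above:

```python
# Made-up argument values; the arithmetic mirrors the new code above.
train_iters = 1000000
eval_interval = 1000
eval_iters = 100
batch_size = 4
data_parallel_size = 8

global_batch_size = batch_size * data_parallel_size              # 32
eval_total = (train_iters // eval_interval + 1) * eval_iters     # 100100
train_val_test_num_samples = [train_iters * global_batch_size,   # 32,000,000
                              eval_total * global_batch_size,    # 3,203,200
                              eval_iters * global_batch_size]    # 3,200
print(train_val_test_num_samples)
```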