Commit 39181113 authored by mshoeybi's avatar mshoeybi Committed by Deepak Narayanan
Browse files

Last epoch should not be globally shuffled

parent 56243e19
...@@ -219,9 +219,47 @@ def _build_index_mappings(name, data_prefix, documents, sizes, ...@@ -219,9 +219,47 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
print_rank_0(' > WARNING: could not find index map files, building ' print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...') 'the indices on rank 0 ...')
# For the last epoch, decide whether include the entire epoch
# in the global shuffle or not.
# If we need only one epoch, then separating last epoch does
# not mean anything.
if num_epochs == 1:
separate_last_epoch = False
print(' > only one epoch required, setting '
'separate_last_epoch to False', flush=True)
else:
# Get the number of samples for the last epoch
num_samples_from_epochs_minus_one = (
(num_epochs - 1) * tokens_per_epoch - 1) // seq_length
last_epoch_num_samples = num_samples - \
num_samples_from_epochs_minus_one
assert last_epoch_num_samples >= 0, \
'last epoch number of samples should be non-negative.'
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
'last epoch number of samples exceeded max value.'
# If we have less than 80% of the samples for the last epoch,
# seperate out the epoch and treat it differently.
separate_last_epoch = (last_epoch_num_samples <
int(0.80 * num_samples_per_epoch))
if separate_last_epoch:
string = ' > last epoch number of samples ({}) is smaller '\
'than 80% of number of samples per epoch ({}), '\
'setting separate_last_epoch to True'
else:
string = ' > last epoch number of samples ({}) is larger '\
'than 80% of number of samples per epoch ({}), '\
'setting separate_last_epoch to False'
print(string.format(last_epoch_num_samples,
num_samples_per_epoch), flush=True)
# doc-idx. # doc-idx.
start_time = time.time() start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng) doc_idx = _build_doc_idx(documents, num_epochs, np_rng,
separate_last_epoch)
np.save(doc_idx_filename, doc_idx, allow_pickle=True) np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save doc-idx mapping ' print_rank_0(' > elasped time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time)) '(seconds): {:4f}'.format(time.time() - start_time))
...@@ -245,7 +283,12 @@ def _build_index_mappings(name, data_prefix, documents, sizes, ...@@ -245,7 +283,12 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
start_time = time.time() start_time = time.time()
# -1 is due to data structure used to retieve the index: # -1 is due to data structure used to retieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1]) # sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) if separate_last_epoch:
num_samples_ = num_samples_from_epochs_minus_one
else:
num_samples_ = sample_idx.shape[0] - 1
shuffle_idx = _build_shuffle_idx(num_samples_,
sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elasped time to build and save shuffle-idx mapping' print_rank_0(' > elasped time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time)) ' (seconds): {:4f}'.format(time.time() - start_time))
...@@ -300,15 +343,20 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples): ...@@ -300,15 +343,20 @@ def _num_epochs(tokens_per_epoch, seq_length, num_samples):
return num_epochs return num_epochs
def _build_doc_idx(documents, num_epochs, np_rng): def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch):
"""Build an array with length = number-of-epochs * number-of-dcuments. """Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document.""" Each index is mapped to a corresponding document."""
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] if not separate_last_epoch or num_epochs == 1:
doc_idx[:] = documents doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx = doc_idx.reshape(-1) doc_idx[:] = documents
doc_idx = doc_idx.astype(np.int32) doc_idx = doc_idx.reshape(-1)
np_rng.shuffle(doc_idx) doc_idx = doc_idx.astype(np.int32)
return doc_idx np_rng.shuffle(doc_idx)
return doc_idx
doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False)
doc_idx_last = _build_doc_idx(documents, 1, np_rng, False)
return np.concatenate((doc_idx_first, doc_idx_last))
def _build_sample_idx(sizes, doc_idx, seq_length, def _build_sample_idx(sizes, doc_idx, seq_length,
...@@ -360,11 +408,23 @@ def _build_sample_idx(sizes, doc_idx, seq_length, ...@@ -360,11 +408,23 @@ def _build_sample_idx(sizes, doc_idx, seq_length,
return sample_idx return sample_idx
def _build_shuffle_idx(size, np_rng): def def _build_shuffle_idx(num_samples, total_size, np_rng):
"""Build the range [0, size) and shuffle.""" """Build the range [0, size) and shuffle."""
print(' > building shuffle index with split [0, {}) and [{}, {}) '
'...'.format(num_samples, num_samples, total_size), flush=True)
dtype_ = np.uint32 dtype_ = np.uint32
if size >= (np.iinfo(np.uint32).max - 1): if total_size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64 dtype_ = np.int64
shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx) shuffle_idx_first = np.arange(start=0, stop=num_samples,
return shuffle_idx step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = np.arange(start=num_samples, stop=total_size,
step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx_last)
return np.concatenate((shuffle_idx_first, shuffle_idx_last))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment