Commit eb74fa34 authored by Mohammad

Fixed a bug related to the last index in shuffle_idx

parent 5f174c07
@@ -104,7 +104,9 @@ class GPT2Dataset(torch.utils.data.Dataset):

     def __len__(self):
-        return self.sample_idx.shape[0]
+        # -1 is due to the data structure used to retrieve the index:
+        #   sample i --> [sample_idx[i], sample_idx[i+1])
+        return self.sample_idx.shape[0] - 1

     def __getitem__(self, idx):
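For context on the new comment: sample_idx acts as a fencepost array, where entry i marks where sample i begins and entry i + 1 marks where it ends, so an array with N + 1 entries describes only N samples. A minimal sketch of that counting, using made-up boundary values rather than the real index-building code:

import numpy as np

# Made-up boundaries: 5 entries describe only 4 samples.
sample_idx = np.array([0, 3, 7, 12, 18])

num_samples = sample_idx.shape[0] - 1  # the "- 1" from __len__
for i in range(num_samples):
    # sample i --> [sample_idx[i], sample_idx[i+1])
    print('sample {} covers [{}, {})'.format(i, sample_idx[i], sample_idx[i + 1]))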
@@ -185,7 +187,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                  '(seconds): {:4f}'.format(time.time() - start_time))
     # shuffle-idx.
     start_time = time.time()
-    shuffle_idx = _build_shuffle_idx(sample_idx.shape[0], np_rng)
+    # -1 is due to the data structure used to retrieve the index:
+    #   sample i --> [sample_idx[i], sample_idx[i+1])
+    shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
     np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
     print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
                  ' (seconds): {:4f}'.format(time.time() - start_time))
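The same off-by-one is why _build_shuffle_idx now receives sample_idx.shape[0] - 1: with the full shape, the shuffle can emit the index of the last boundary entry, and reading that sample's right boundary runs past the end of the array. A small repro of the failure mode, again with made-up values:

import numpy as np

np_rng = np.random.RandomState(1234)
sample_idx = np.array([0, 3, 7, 12, 18])  # 5 boundaries, 4 samples

# Old behavior: shuffled indices 0..4, but sample 4 has no right boundary.
shuffle_idx = np.arange(sample_idx.shape[0])
np_rng.shuffle(shuffle_idx)
for idx in shuffle_idx:
    try:
        start, end = sample_idx[idx], sample_idx[idx + 1]
    except IndexError:
        print('sample {} has no right boundary'.format(idx))  # fires for idx == 4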
@@ -306,91 +310,3 @@ def _build_shuffle_idx(size, np_rng):
     shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
     np_rng.shuffle(shuffle_idx)
     return shuffle_idx
-
-
-'''
-class IndexedDataset:
-
-    def __init__(self, num_docs, min_doc_length, max_doc_length, seq_length):
-        self.seq_length = seq_length
-        assert min_doc_length > 0
-        self.tokens = []
-        self.sizes = np.zeros(num_docs, dtype=np.int32)
-        for i in range(num_docs):
-            size = np.random.randint(low=min_doc_length, high=max_doc_length,
-                                     size=1, dtype=np.uint32)[0]
-            tokens_ = np.random.randint(low=1, high=60000,
-                                        size=size, dtype=np.uint32)
-            tokens_[-1] = 0
-            self.sizes[i] = size
-            self.tokens.append(tokens_)
-        self.tokens_flat = None
-
-    def get(self, doc_idx, offset=None, length=None):
-        if length is None:
-            if offset is None:
-                return self.tokens[doc_idx]
-            else:
-                return self.tokens[doc_idx][offset:]
-        if offset is None:
-            return self.tokens[doc_idx][0:length]
-        return self.tokens[doc_idx][offset:(offset + length)]
-
-    def get_sample(self, index):
-        start = index * self.seq_length
-        end = start + self.seq_length + 1
-        return self.tokens_flat[start:end]
-
-    def build_tokens_flat(self, doc_idx):
-        self.tokens_flat = np.concatenate([self.tokens[i] for i in doc_idx])
-
-
-def test(seed, data_prefix, seq_length, num_samples,
-         num_docs, min_doc_length, max_doc_length):
-    print('testing for seed: {}, seq-length: {}, num-samples: {}, '
-          'num-docs: {}, min-doc-length: {}, max-doc-length: {}'.format(
-              seed, seq_length, num_samples,
-              num_docs, min_doc_length, max_doc_length))
-    np.random.seed(seed)
-    indexed_dataset = IndexedDataset(num_docs, min_doc_length,
-                                     max_doc_length, seq_length)
-    indices = np.random.randint(indexed_dataset.sizes.shape[0] - 2, size=2)
-    documents = np.arange(np.min(indices), np.max(indices) + 1)
-    dataset = GPT2Dataset('gpt2', data_prefix, documents, indexed_dataset,
-                          num_samples, seq_length, seed)
-    print(' > number of epochs:', dataset.num_epochs)
-    indexed_dataset.build_tokens_flat(dataset.doc_idx)
-    for idx in range(num_samples):
-        a = dataset[idx]
-        b = indexed_dataset.get_sample(idx)
-        assert np.sum(a - b) == 0
-    print('passed')
-
-
-if __name__ == '__main__':
-    print('gpt2 dataset ...')
-    import random
-    data_prefix = 'junk/'
-    for seed in range(1234, 1245):
-        random.seed(seed)
-        num_docs = random.randint(1, 999)
-        min_doc_length = random.randint(1, 99)
-        max_doc_length = random.randint(100, 9999)
-        num_samples = random.randint(num_docs, 100 * num_docs)
-        seq_length = random.randint(min_doc_length, max_doc_length)
-        test(seed, data_prefix, seq_length, num_samples,
-             num_docs, min_doc_length, max_doc_length)
-'''