Commit c3691006 authored by Mohammad

testing new gpt2 dataset

parent 836c6776
@@ -24,7 +24,6 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data import helpers
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
@@ -249,6 +248,7 @@ def get_samples_mapping_(indexed_dataset,
     start_time = time.time()
     print_rank_0(' > building samples index mapping for {} ...'.format(
         name))
+    from megatron.data import helpers
     samples_mapping = helpers.build_mapping(
         indexed_dataset.doc_idx,
         indexed_dataset.sizes,
...
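Note on the hunks above: megatron.data.helpers is a compiled C++ extension, so moving its import from module scope to inside get_samples_mapping_ defers loading the extension until a samples mapping is actually built; importing the dataset module alone presumably no longer requires the extension to be compiled. A minimal sketch of the same deferred-import pattern (the wrapper name and trimmed argument list are illustrative, not from this commit):

def build_mapping_lazily(doc_idx, sizes):
    # Deferred import: the compiled helpers extension is only needed
    # when the mapping is built, not at module import time.
    from megatron.data import helpers
    return helpers.build_mapping(doc_idx, sizes)  # the real call takes more arguments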
@@ -13,26 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""GPT2 Style dataset."""
+"""GPT2 style dataset."""

 import os
 import time

 import numpy as np
 import torch
-from torch.utils.data import Dataset

-import helpers
-#from bert_dataset import get_train_valid_test_split_
+from megatron import print_rank_0
+from megatron import mpu
+from megatron.data.bert_dataset import get_train_valid_test_split_
+from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

-def print_rank_0(message):
-    print(message)

 def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     seq_length, seed, skip_warmup):
+    """Build train, valid, and test datasets."""

     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
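For context, splits_string carries the relative train/valid/test weights, and get_train_valid_test_split_ converts them into document-index boundaries over the corpus. A hedged usage sketch (the weight values and corpus size are made up):

from megatron.data.bert_dataset import get_train_valid_test_split_

# Hypothetical 949/50/1 split over a corpus of 10000 documents.
# `splits` is then four boundaries, roughly [0, 9490, 9990, 10000];
# build_dataset(index, name) below slices documents[splits[index]:splits[index + 1]].
splits = get_train_valid_test_split_('949,50,1', 10000)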
@@ -56,7 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     def build_dataset(index, name):
         dataset = None
         if splits[index + 1] > splits[index]:
-            documents = np.arange(start=splits[index], end=splits[index+1],
+            documents = np.arange(start=splits[index], stop=splits[index+1],
                                   step=1, dtype=np.int32)
             dataset = GPT2Dataset(name, data_prefix,
                                   documents, indexed_dataset,
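The end= to stop= change above is a genuine bug fix rather than a rename: numpy.arange accepts start, stop, step, and dtype, and passing end raises a TypeError. For example:

import numpy as np

docs = np.arange(start=0, stop=5, step=1, dtype=np.int32)  # array([0, 1, 2, 3, 4], dtype=int32)
# np.arange(start=0, end=5)  # raises TypeError: unexpected keyword argument 'end'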
@@ -72,7 +70,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,

 def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
+    """Build indexed dataset."""
     print_rank_0(' > building dataset index ...')
     start_time = time.time()
@@ -81,25 +79,18 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
                                            skip_warmup)
     print_rank_0(' > finished creating indexed dataset in {:4f} '
                  'seconds'.format(time.time() - start_time))
+    print_rank_0(' > indexed dataset stats:')
     print_rank_0('    number of documents: {}'.format(
         indexed_dataset.sizes.shape[0]))

     return indexed_dataset
-class GPT2Dataset(Dataset):
+class GPT2Dataset(torch.utils.data.Dataset):

-    def __init__(self, name, data_prefix,
-                 documents, indexed_dataset,
+    def __init__(self, name, data_prefix, documents, indexed_dataset,
                  num_samples, seq_length, seed):

         self.name = name
-        self.data_prefix = data_prefix
-        self.num_samples = num_samples
-        self.seq_length = seq_length
-        self.seed = seed
         self.indexed_dataset = indexed_dataset

         # Checks
@@ -107,11 +98,9 @@ class GPT2Dataset(Dataset):
         assert np.max(documents) < indexed_dataset.sizes.shape[0]

         # Build index mappings.
-        self.num_epochs, self.doc_idx, self.sample_idx, self.shuffle_idx \
-            = _build_index_mappings(self.name, self.data_prefix, documents,
-                                    self.indexed_dataset.sizes,
-                                    self.num_samples, self.seq_length,
-                                    self.seed)
+        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
+            self.name, data_prefix, documents, self.indexed_dataset.sizes,
+            num_samples, seq_length, seed)

     def __len__(self):
@@ -144,7 +133,7 @@ class GPT2Dataset(Dataset):
                                               length=offset_l+1))
             sample = np.concatenate(sample_list)

-        return sample
+        return {'text': np.array(sample, dtype=np.int64)}
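Returning {'text': ...} instead of the bare array means __getitem__ now yields a named field, which is what batch-building code downstream typically indexes by key; each sample presumably holds seq_length + 1 tokens so inputs and labels can be shifted by one position. A minimal consumer sketch (the DataLoader wiring is illustrative, not part of this commit):

from torch.utils.data import DataLoader

# `dataset` is a constructed GPT2Dataset instance.
loader = DataLoader(dataset, batch_size=4)
for batch in loader:
    tokens = batch['text']  # LongTensor, presumably of shape [4, seq_length + 1]
    break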
@@ -168,7 +157,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
     shuffle_idx_filename = _filename + '_shuffle_idx.npy'

     # Build the indexed mapping if not exist.
-    if True: #torch.distributed.get_rank() == 0:
+    if torch.distributed.get_rank() == 0:
         if (not os.path.isfile(doc_idx_filename)) or \
            (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):
@@ -183,7 +172,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
                      '(seconds): {:4f}'.format(time.time() - start_time))
         # sample-idx.
         start_time = time.time()
-        import helpers
+        # Use C++ implementation for speed.
+        from megatron.data import helpers
+        assert doc_idx.dtype == np.int32
+        assert sizes.dtype == np.int32
         sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
                                               num_epochs, tokens_per_epoch)
         #sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
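The new dtype asserts guard the C++ helper, which presumably indexes with 32-bit integers, while the commented-out _build_sample_idx call points at the pure-Python reference kept as a fallback. Conceptually, sample_idx is a [num_samples + 1, 2] int32 array in which row i records where sample i begins: an index into doc_idx and a token offset inside that document. A toy illustration of that layout (values made up):

import numpy as np

# Sample i spans the tokens from sample_idx[i] up to sample_idx[i + 1],
# possibly crossing document boundaries along the way.
sample_idx = np.array([[0, 0],   # sample 0 starts in doc_idx[0] at token 0
                       [1, 2],   # sample 1 starts in doc_idx[1] at token 2
                       [3, 0]],  # end marker: where sample 1 stops
                      dtype=np.int32)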
@@ -202,9 +194,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         # device_index=rank which is not the case for model
         # parallel case
         counts = torch.cuda.LongTensor([1])
-        #torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-        #assert counts[0].item() == torch.distributed.get_world_size(
-        #    group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        assert counts[0].item() == torch.distributed.get_world_size(
+            group=mpu.get_data_parallel_group())

     # Load mappings.
     start_time = time.time()
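Re-enabling the all_reduce restores the synchronization this block depends on: rank 0 writes the .npy mapping files, and the collective doubles as a barrier, so the other ranks only fall through to loading once rank 0 has finished (the comment above explains why a plain torch.distributed.barrier() is avoided under model parallelism). A hedged sketch of the pattern, with illustrative function names:

import torch

def rank0_builds_then_all_load(build_fn, load_fn, group):
    if torch.distributed.get_rank() == 0:
        build_fn()  # only one rank materializes the files
    # A collective cannot complete until every rank reaches it, so this
    # also orders load_fn() after build_fn() on every rank.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=group)
    assert counts[0].item() == torch.distributed.get_world_size(group=group)
    return load_fn()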
@@ -221,8 +213,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         time.time() - start_time))
     print_rank_0('    total number of samples: {}'.format(
         sample_idx.shape[0]))
+    print_rank_0('    total number of epochs: {}'.format(num_epochs))

-    return num_epochs, doc_idx, sample_idx, shuffle_idx
+    return doc_idx, sample_idx, shuffle_idx


 def _num_tokens(documents, sizes):
@@ -311,10 +304,11 @@ def _build_shuffle_idx(size, np_rng):
     if size >= (np.iinfo(np.uint32).max - 1):
         dtype_ = np.int64
     shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
-    #np_rng.shuffle(shuffle_idx)
+    np_rng.shuffle(shuffle_idx)
     return shuffle_idx

+'''
 class IndexedDataset:
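Un-commenting np_rng.shuffle makes the shuffle index a real permutation again instead of the identity; because np_rng is seeded upstream, any rank that rebuilds the mapping derives the same order. For instance:

import numpy as np

np_rng = np.random.RandomState(seed=1234)
shuffle_idx = np.arange(start=0, stop=8, step=1, dtype=np.uint32)
np_rng.shuffle(shuffle_idx)  # in-place; the same seed always yields the same permutation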
@@ -399,53 +393,4 @@ if __name__ == '__main__':
     test(seed, data_prefix, seq_length, num_samples,
          num_docs, min_doc_length, max_doc_length)
-    exit()
+'''
-'''
-num_docs = 5
-min_doc_length = 2
-max_doc_length = 10
-num_samples = 9
-seq_length = 4
-seed = 1234
-np.random.seed(seed)
-indexed_dataset = IndexedDataset(num_docs, min_doc_length,
-                                 max_doc_length, seq_length)
-print('> indexed dataset:')
-for s in indexed_dataset.tokens:
-    print('  {}'.format(s))
-documents = np.array([1,2,3], dtype=np.int32)
-dataset = GPT2Dataset('gpt2', documents, indexed_dataset,
-                      num_samples, seq_length, seed)
-indexed_dataset.build_tokens_flat(dataset.doc_idx)
-print(indexed_dataset.get_sample(6))
-print(dataset[6])
-'''
-'''
-myds = MyDataset(ds, num_samples, seq_length)
-num_docs = myds._num_docs()
-print('> number of document: {}'.format(num_docs))
-tokens_per_epoch = myds._num_tokens()
-print('> number of tokens: {}'.format(tokens_per_epoch))
-num_epochs = myds._num_epochs(tokens_per_epoch)
-print('> number of epochs: {}'.format(num_epochs))
-doc_idx = myds._build_doc_idx(num_docs, num_epochs)
-print('> doc_idx: {}'.format(doc_idx))
-ds.build_tokens_flat(doc_idx)
-sample_idx = myds._build_sample_idx(num_epochs, tokens_per_epoch, doc_idx)
-for s in sample_idx:
-    print(s)
-print(ds.tokens_flat)
-print(myds.get_sample(8))
-print(ds.get_sample(8))
-'''
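Taken together, the commit moves the GPT2 dataset onto the shared megatron.data utilities and retires the standalone test scaffolding above. An end-to-end sketch of how the builder is presumably invoked, matching the signature in the diff (the data prefix, sample counts, and options are made up):

# build_train_valid_test_datasets is the function defined in this file.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-gpt2_text_document',  # hypothetical preprocessed corpus
    data_impl='mmap',
    splits_string='949,50,1',
    train_valid_test_num_samples=[1000, 100, 10],
    seq_length=1024,
    seed=1234,
    skip_warmup=True)

sample = train_ds[0]  # {'text': np.ndarray of int64 token ids}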