Commit 898fcb94 authored by Mohammad

moved current gpt2 dataloader to old

parent c3691006
megatron/data/gpt2_dataset.py
@@ -13,124 +13,384 @@
The shard-based GPT2Dataset was replaced by the new index-mapping implementation. The complete new gpt2_dataset.py, and the previous implementation now moved to megatron/data/old_gpt2_dataset.py, appear in full below.
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 style dataset."""
import os
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index+1],
step=1, dtype=np.int32)
dataset = GPT2Dataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' number of documents: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
class GPT2Dataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed):
self.name = name
self.indexed_dataset = indexed_dataset
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
def __len__(self):
# Minus one because __getitem__ reads sample_idx[idx + 1] to find where a sample ends.
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx+1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx+1][1]
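# Note: each sample spans seq_length + 1 tokens; the extra token provides the shifted labels
# for next-token prediction.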
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1)
else:
# Otherwise, get the rest of the initial document.
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f)]
# Loop over all in-between documents and add each entire document.
for i in range(doc_index_f+1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of the last document.
sample_list.append(self.indexed_dataset.get(
self.doc_idx[doc_index_l],
length=offset_l+1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples, seq_length, seed):
"""doc-idx, sample-idx, and shuffle-idx."""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
_filename += '_{}s'.format(seed)
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
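# For example, data_prefix='my-gpt2', name='train', num_samples=1000, seq_length=1024, and
# seed=1234 give 'my-gpt2_train_indexmap_1000ns_1024sl_1234s_doc_idx.npy' (and similarly for
# the sample-idx and shuffle-idx files).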
# Build the index mappings if they do not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
from megatron.data import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
#sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# shuffle-idx.
start_time = time.time()
# Build the shuffle index over the usable samples (one fewer than the rows of sample_idx).
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for the model
# parallel case.
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True)
print_rank_0(' > loading sample-idx mapping from {}'.format(
sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True)
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
sample_idx.shape[0]))
print_rank_0(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on number of samples and sequence lenght, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
# -1 is because we need to retrieve seq_length + 1 token each time
# but the last token will overlap with the first token of the next
# sample except for the last sample.
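# For example, with tokens_per_epoch=10 and seq_length=4, each sample needs 5 tokens but
# shares one with the next sample, so a single epoch yields (10 - 1) // 4 = 2 samples.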
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
def _build_doc_idx(documents, num_epochs, np_rng):
"""Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document."""
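# For example, documents=[5, 6, 7] with num_epochs=2 produces [5, 6, 7, 5, 6, 7] before shuffling.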
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
def _build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch):
"""Sample index mapping is a 2D array with sizes
[number-of-samples + 1, 2] where [..., 0] contains
the index into `doc_idx` and [..., 1] is the
starting offset in that document."""
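# For example, sample_idx[i] = [3, 7] and sample_idx[i + 1] = [5, 2] mean that sample i starts
# at token 7 of doc_idx[3] and ends at token 2 (inclusive) of doc_idx[5].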
# Total number of samples. For -1 see comments in `_num_epochs`.
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
# Index into sample_idx.
sample_index = 0
# Index into doc_idx.
doc_idx_index = 0
# Beginning offset for each document.
doc_offset = 0
# Start with first document and no offset.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
while sample_index <= num_samples:
# Start with a fresh sequence.
remaining_seq_length = seq_length + 1
while remaining_seq_length != 0:
# Get the document length.
doc_id = doc_idx[doc_idx_index]
doc_length = sizes[doc_id] - doc_offset
# And add it to the current sequence.
remaining_seq_length -= doc_length
# If we have more than a full sequence, adjust offset and set
# remaining length to zero so we return from the while loop.
# Note that -1 here is for the same reason we have -1 in
# `_num_epochs` calculations.
if remaining_seq_length <= 0:
doc_offset += (remaining_seq_length + doc_length - 1)
remaining_seq_length = 0
else:
# Otherwise, start from the beginning of the next document.
doc_idx_index += 1
doc_offset = 0
# Record the sequence.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
return sample_idx
def _build_shuffle_idx(size, np_rng):
"""Build the range [0, size) and shuffle."""
dtype_ = np.uint32
if size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx)
return shuffle_idx
'''
class IndexedDataset:
def __init__(self, num_docs, min_doc_length, max_doc_length, seq_length):
self.seq_length = seq_length
assert min_doc_length > 0
self.tokens = []
self.sizes = np.zeros(num_docs, dtype=np.int32)
for i in range(num_docs):
size = np.random.randint(low=min_doc_length, high=max_doc_length,
size=1, dtype=np.uint32)[0]
tokens_ = np.random.randint(low=1, high=60000,
size=size, dtype=np.uint32)
tokens_[-1] = 0
self.sizes[i] = size
self.tokens.append(tokens_)
self.tokens_flat = None
def get(self, doc_idx, offset=None, length=None):
if length is None:
if offset is None:
return self.tokens[doc_idx]
else:
return self.tokens[doc_idx][offset:]
if offset is None:
return self.tokens[doc_idx][0:length]
return self.tokens[doc_idx][offset:(offset+length)]
def get_sample(self, index):
start = index * self.seq_length
end = start + self.seq_length + 1
return self.tokens_flat[start:end]
def build_tokens_flat(self, doc_idx):
self.tokens_flat = np.concatenate([self.tokens[i] for i in doc_idx])
def test(seed, data_prefix, seq_length, num_samples,
num_docs, min_doc_length, max_doc_length):
print('testing for seed: {}, seq-length: {}, num-samples: {}, '
'num-docs: {}, min-doc-length: {}, max-doc-length: {}'.format(
seed, seq_length, num_samples,
num_docs, min_doc_length, max_doc_length))
np.random.seed(seed)
indexed_dataset = IndexedDataset(num_docs, min_doc_length,
max_doc_length, seq_length)
indices = np.random.randint(indexed_dataset.sizes.shape[0]-2, size=2)
documents = np.arange(np.min(indices), np.max(indices)+1)
dataset = GPT2Dataset('gpt2', data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed)
print(' > number of epochs:', dataset.num_epochs)
indexed_dataset.build_tokens_flat(dataset.doc_idx)
for idx in range(num_samples):
a = dataset[idx]
b = indexed_dataset.get_sample(idx)
assert np.sum(a - b) == 0
print('passed')
if __name__ == '__main__':
print('gpt2 dataset ...')
import random
data_prefix = 'junk/'
for seed in range(1234, 1245):
random.seed(seed)
num_docs = random.randint(1, 999)
min_doc_length = random.randint(1, 99)
max_doc_length = random.randint(100, 9999)
num_samples = random.randint(num_docs, 100*num_docs)
seq_length = random.randint(min_doc_length, max_doc_length)
test(seed, data_prefix, seq_length, num_samples,
num_docs, min_doc_length, max_doc_length)
'''
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 dataset."""
import json
import os
import numpy as np
import torch
from torch.utils.data import Dataset
class GPT2Dataset(Dataset):
def __init__(self, data_path, sizes_filename, seq_length,
initial_seed, max_epochs=100):
# Input parameters.
self.data_path = data_path
self.sizes_filename = sizes_filename
self.seq_length = seq_length
self.initial_seed = initial_seed
self.max_epochs = max_epochs
# Shard stuff.
# Dictionary from shard name to its size (number of elements).
self.master_shard_size_dict = None
# Dictionary from shard name to modified size so it is
# divisible by self.seq_length.
self.shard_size_dict = None
# Long array (self.max_epochs * num-shards) populated
# randomly with shard names.
self.shards_name = None
# Start index of the data for a shard.
self.shards_start_index = None
self.build_shard_mappings_()
self.data_length = self.shards_start_index[-1]
# Data.
self.shards_data = [None]*self.shards_name.size
self.shards_sample_index = [None]*self.shards_name.size
def __len__(self):
return self.data_length
def __getitem__(self, idx):
# Find which shard we need.
shard_index = np.searchsorted(self.shards_start_index,
idx, side='right') - 1
# data index in the shard.
data_idx = idx - self.shards_start_index[shard_index]
# Load the shard if it is not in memory.
if self.shards_data[shard_index] is None:
print('global rank {} is building data for shard index {} ...'.
format(torch.distributed.get_rank(), shard_index))
self.build_dataset_(shard_index)
#assert self.shards_data[shard_index] is not None
# Start index.
start_index = self.shards_sample_index[shard_index][data_idx]
# Add one for label shift.
end_index = start_index + self.seq_length + 1
data = self.shards_data[shard_index][start_index:end_index]
return {'text': np.array(data, dtype=np.int64)}
def build_dataset_(self, shard_index):
# Garbage collect so we don't use a lot of memory.
# Leave the last one in case other threads have not caught up yet.
#for i in range(shard_index - 1):
for i in range(shard_index):
self.shards_data[i] = None
self.shards_sample_index[i] = None
# Read the shard.
filename = os.path.join(self.data_path, self.shards_name[shard_index])
print('loading {}'.format(filename))
data = np.load(filename, allow_pickle=True)
# Shuffle the data
rng = np.random.RandomState(self.initial_seed + shard_index)
rng.shuffle(data)
# Flatten.
data = np.hstack(data)
size = (data.shape[0] - 1) // self.seq_length
last_index = size * self.seq_length + 1
data = data[0:last_index]
self.shards_data[shard_index] = data
indices = np.arange(size) * self.seq_length
rng.shuffle(indices)
self.shards_sample_index[shard_index] = indices
def build_shard_mappings_(self):
# Load the sizes file.
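# The sizes file is assumed to be a JSON dict mapping each shard file name to its number of
# elements (tokens), e.g. {"shard_000.npy": 123456789, ...}.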
sizes_filename = os.path.join(self.data_path, self.sizes_filename)
if torch.distributed.get_rank() == 0:
print(' > loading sizes from {}'.format(sizes_filename))
with open(sizes_filename, 'r') as f:
self.master_shard_size_dict = json.load(f)
if torch.distributed.get_rank() == 0:
print(' found {} shards'.format(len(self.master_shard_size_dict)))
# Adjust sizes to be a multiple of seq_length.
self.shard_size_dict = self.master_shard_size_dict.copy()
total_samples = 0
for shard in self.shard_size_dict:
size = self.shard_size_dict[shard]
size = ((size - 1) // self.seq_length) * self.seq_length
total_samples += size // self.seq_length
self.shard_size_dict[shard] = size
if torch.distributed.get_rank() == 0:
print(' found {} samples in the dataset'.format(total_samples))
# Build a list of shards.
shards_ = np.sort(np.array(list(self.shard_size_dict.keys())))
rng = np.random.RandomState(self.initial_seed)
self.shards_name = np.copy(shards_)
rng.shuffle(self.shards_name)
for i in range(1, self.max_epochs):
shards_c = np.copy(shards_)
rng.shuffle(shards_c)
self.shards_name = np.append(self.shards_name, shards_c)
# Build the global indexing.
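# shards_start_index[i] is the global index of the first sample served by shard i;
# __getitem__ uses searchsorted on this array to locate the shard for a given sample.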
self.shards_start_index = np.zeros(self.shards_name.size, dtype=np.int)
self.shards_start_index[0] = 0
for i in range(1, self.shards_name.size):
shard = str(self.shards_name[i-1])
size = self.shard_size_dict[shard]
self.shards_start_index[i] = self.shards_start_index[i-1] + \
size // self.seq_length
pretrain_gpt2.py
@@ -24,7 +24,7 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron import print_rank_0
-from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.gpt2_dataset import build_train_valid_test_datasets
 from megatron.model import GPT2Model
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
@@ -98,56 +98,53 @@ def forward_step(data_iterator, model):
     return loss, {'lm loss': reduced_loss[0]}

-def make_gpt2_dataloaders():
-    """Build gpt2 dataloaders."""
-    args = get_args()
-
-    # Input parameters.
-    input_data_sizes_file = args.input_data_sizes_file
-    seq_length = args.seq_length
-    initial_seed = args.seed
-
-    # Build the datasets.
-    def _build_dataset(name):
-        return GPT2Dataset(os.path.join(args.data_path, name),
-                           args.input_data_sizes_file,
-                           args.seq_length, args.seed)
-    train_ds = _build_dataset('train')
-    valid_ds = _build_dataset('valid')
-    test_ds = _build_dataset('test')
-
-    # Dataloaders
-    train = make_data_loader(train_ds)
-    valid = make_data_loader(valid_ds)
-    test = make_data_loader(test_ds)
-
-    args.do_train = False
-    args.do_valid = False
-    args.do_test = False
-    if train is not None:
-        args.do_train = True
-    if valid is not None:
-        args.do_valid = True
-    if test is not None:
-        args.do_test = True
-
-    return (train, valid, test)
-
 def get_train_val_test_data():
     """Load the data on rank zero and broadcast number of tokens to all GPUs."""
     args = get_args()

-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        (train_data, val_data, test_data) = make_gpt2_dataloaders()
-        flags = torch.cuda.LongTensor([int(args.do_train),
-                                       int(args.do_valid),
-                                       int(args.do_test)])
+        print_rank_0('> building train, validation, and test datasets '
+                     'for GPT2 ...')
+
+        data_parallel_size = mpu.get_data_parallel_world_size()
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        global_batch_size = args.batch_size * data_parallel_size
+        # Number of train/valid/test samples.
+        train_iters = args.train_iters
+        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        test_iters = args.eval_iters
+        train_val_test_num_samples = [train_iters * global_batch_size,
+                                      eval_iters * global_batch_size,
+                                      test_iters * global_batch_size]
+        print_rank_0(' > datasets target sizes (minimum size):')
+        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
+        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
+
+        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            data_prefix=args.data_path,
+            data_impl=args.data_impl,
+            splits_string=args.split,
+            train_valid_test_num_samples=train_val_test_num_samples,
+            seq_length=args.seq_length,
+            seed=args.seed,
+            skip_warmup=(not args.mmap_warmup))
+        print_rank_0("> finished creating GPT2 datasets ...")
+
+        train_data = make_data_loader(train_ds)
+        valid_data = make_data_loader(valid_ds)
+        test_data = make_data_loader(test_ds)
+
+        do_train = train_data is not None and args.train_iters > 0
+        do_valid = valid_data is not None and args.eval_iters > 0
+        do_test = test_data is not None and args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor(
+            [int(do_train), int(do_valid), int(do_test)])
     else:
         flags = torch.cuda.LongTensor([0, 0, 0])
@@ -159,7 +156,7 @@ def get_train_val_test_data():
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data

 if __name__ == "__main__":
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT2"""
import os
import torch
from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.old_gpt2_dataset import GPT2Dataset
from megatron.model import GPT2Model
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
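# tokens_ holds seq_length + 1 tokens per sequence; inputs and labels are the same sequence
# shifted by one token.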
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
args.fp16)
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch generator').stop()
# Forward model.
output = model(tokens, position_ids, attention_mask)
losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
labels)
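# Average the per-token losses over the non-masked positions only.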
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
reduced_loss = reduce_losses([loss])
return loss, {'lm loss': reduced_loss[0]}
def make_gpt2_dataloaders():
"""Build gpt2 dataloders."""
args = get_args()
# Input parameters.
input_data_sizes_file = args.input_data_sizes_file
seq_length = args.seq_length
initial_seed = args.seed
# Build the datasets.
def _build_dataset(name):
return GPT2Dataset(os.path.join(args.data_path, name),
args.input_data_sizes_file,
args.seq_length, args.seed)
train_ds = _build_dataset('train')
valid_ds = _build_dataset('valid')
test_ds = _build_dataset('test')
# Dataloaders
train = make_data_loader(train_ds)
valid = make_data_loader(valid_ds)
test = make_data_loader(test_ds)
args.do_train = False
args.do_valid = False
args.do_test = False
if train is not None:
args.do_train = True
if valid is not None:
args.do_valid = True
if test is not None:
args.do_test = True
return (train, valid, test)
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
(train_data, val_data, test_data) = make_gpt2_dataloaders()
flags = torch.cuda.LongTensor([int(args.do_train),
int(args.do_valid),
int(args.do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
return train_data, val_data, test_data
if __name__ == "__main__":
pretrain(get_train_val_test_data, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})