Commit b03af49e authored by Neel Kant

Hacks to build IndexedDataset and run pretrain

parent 4558e42f
@@ -4,7 +4,6 @@ import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_tokenizer
-from .bert_dataset import get_samples_mapping_
 
 class InverseClozeDataset(Dataset):
@@ -18,7 +17,7 @@ class InverseClozeDataset(Dataset):
         self.indexed_dataset = indexed_dataset
-        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
+        self.samples_mapping = get_samples_mapping(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
@@ -160,3 +159,84 @@ class InverseClozeDataset(Dataset):
                 (context_tokens, context_token_types, context_pad_mask)
             else:
                 raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
+
+
+def get_samples_mapping(indexed_dataset,
+                        data_prefix,
+                        num_epochs,
+                        max_num_samples,
+                        max_seq_length,
+                        short_seq_prob,
+                        seed,
+                        name):
+    if not num_epochs:
+        if not max_num_samples:
+            raise ValueError("Need to specify either max_num_samples "
+                             "or num_epochs")
+        num_epochs = np.iinfo(np.int32).max - 1
+    if not max_num_samples:
+        max_num_samples = np.iinfo(np.int64).max - 1
+
+    # Filename of the index mapping
+    indexmap_filename = data_prefix
+    indexmap_filename += '_{}_indexmap'.format(name)
+    if num_epochs != (np.iinfo(np.int32).max - 1):
+        indexmap_filename += '_{}ep'.format(num_epochs)
+    if max_num_samples != (np.iinfo(np.int64).max - 1):
+        indexmap_filename += '_{}mns'.format(max_num_samples)
+    indexmap_filename += '_{}msl'.format(max_seq_length)
+    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
+    indexmap_filename += '_{}s'.format(seed)
+    indexmap_filename += '.npy'
+
+    # Build the indexed mapping if not exist.
+    if torch.distributed.get_rank() == 0 and \
+       not os.path.isfile(indexmap_filename):
+        print(' > WARNING: could not find index map file {}, building '
+              'the indices on rank 0 ...'.format(indexmap_filename))
+        # Make sure the types match the helpers input types.
+        assert indexed_dataset.doc_idx.dtype == np.int64
+        assert indexed_dataset.sizes.dtype == np.int32
+        # Build samples mapping
+        verbose = torch.distributed.get_rank() == 0
+        start_time = time.time()
+        print_rank_0(' > building samples index mapping for {} ...'.format(
+            name))
+        samples_mapping = helpers.build_mapping(
+            indexed_dataset.doc_idx,
+            indexed_dataset.sizes,
+            num_epochs,
+            max_num_samples,
+            max_seq_length - 3,  # account for added tokens
+            short_seq_prob,
+            seed,
+            verbose)
+        print_rank_0(' > done building samples index mapping')
+        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+        print_rank_0(' > saved the index mapping in {}'.format(
+            indexmap_filename))
+        # Make sure all the ranks have built the mapping
+        print_rank_0(' > elapsed time to build and save samples mapping '
+                     '(seconds): {:4f}'.format(
+                         time.time() - start_time))
+
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_data_parallel_group())
+
+    # Load indexed dataset.
+    print_rank_0(' > loading indexed mapping from {}'.format(
+        indexmap_filename))
+    start_time = time.time()
+    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+        time.time() - start_time))
+    print_rank_0('    total number of samples: {}'.format(
+        samples_mapping.shape[0]))
+
+    return samples_mapping
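
For orientation, a minimal sketch (not part of this commit) of how a dataset would typically consume the mapping returned above, assuming each row is a (start_index, end_index, seq_length) triple over sentence indices as in Megatron's BERT-style datasets; the helper name and variables below are illustrative only.

import numpy as np

# Illustrative only: assumes samples_mapping rows are
# (start_index, end_index, target_seq_length) over indexed_dataset entries.
def get_sample(indexed_dataset, samples_mapping, idx):
    start_index, end_index, seq_length = samples_mapping[idx]
    # Gather the token arrays for the sentences that make up this sample.
    sentences = [indexed_dataset[i] for i in range(start_index, end_index)]
    tokens = np.concatenate(sentences)[:seq_length]
    return tokens, seq_length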
@@ -6,9 +6,10 @@ import sys
 import time
 
 import torch
 
-sys.path.insert(0, '../')
-from bert_tokenization import FullTokenizer
-import indexed_dataset
+sys.path.insert(0, '../../')
+from tokenizer.bert_tokenization import FullTokenizer
+from data.indexed_dataset import make_builder
 
 class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
@@ -23,6 +24,8 @@ class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
         ))"""
 
 class Encoder(object):
+    splitter = None
+    tokenizer = None
     def __init__(self, args):
         self.args = args
@@ -32,7 +35,7 @@ class Encoder(object):
         spliter = nltk.load("tokenizers/punkt/english.pickle")
         if self.args.keep_newlines:
             # this prevents punkt from eating newlines after sentences
-            Encoder.spliter = nltk.tokenize.punkt.PunktSentenceTokenizer(
+            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                 train_text = spliter._params,
                 lang_vars = CustomLanguageVars())
         else:
@@ -82,7 +85,7 @@ def main():
     output_bin_file = "{}.bin".format(args.output_prefix)
     output_idx_file = "{}.idx".format(args.output_prefix)
-    builder = indexed_dataset.make_builder(output_bin_file,
+    builder = make_builder(output_bin_file,
                            impl=args.dataset_impl,
                            vocab_size=tokenizer.vocab_size())
...
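
For context, a minimal sketch (not part of this commit) of how the builder returned by make_builder is typically driven to produce the .bin/.idx pair; the corpus iteration and encode() helper are placeholders, and end_document() is assumed to be available on the chosen dataset impl.

import torch
from data.indexed_dataset import make_builder  # import path as used in the diff

# Illustrative driver: documents and encode() are placeholders.
def build_indexed_dataset(documents, encode, output_prefix, impl, vocab_size):
    output_bin_file = "{}.bin".format(output_prefix)
    output_idx_file = "{}.idx".format(output_prefix)
    builder = make_builder(output_bin_file, impl=impl, vocab_size=vocab_size)
    for doc in documents:
        for sentence in doc:
            ids = encode(sentence)                  # token ids, e.g. from FullTokenizer
            builder.add_item(torch.IntTensor(ids))  # one entry per sentence
        builder.end_document()                      # assumed document-boundary marker
    builder.finalize(output_idx_file)               # writes the .idx index file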