# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Style dataset."""

import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import mpu
from megatron.data import helpers
from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron import print_rank_0


def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
                                    splits_string,
                                    train_valid_test_num_samples,
                                    max_seq_length, masked_lm_prob,
                                    short_seq_prob, seed, skip_warmup):

    # The same tokenizer is shared by all splits.
    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
        tokenizer.vocab_size()))

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # Get start and end indices of train/valid/test into doc-idx.
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    # Print stats about the splits.
    print_rank_0(' > dataset split:')

    def print_split_stats(name, index):
        print_rank_0('    {}:'.format(name))
        print_rank_0('     document indices in [{}, {}) total of {} '
                     'documents'.format(splits[index], splits[index + 1],
                                        splits[index + 1] - splits[index]))
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        print_rank_0('     sentence indices in [{}, {}) total of {} '
                     'sentences'.format(start_index, end_index,
                                        end_index - start_index))

    print_split_stats('train', 0)
    print_split_stats('validation', 1)
    print_split_stats('test', 2)

    def build_dataset(index, name):
        dataset = None
        if splits[index + 1] > splits[index]:
            # Get the pointer to the original doc-idx so we can set it later.
            doc_idx_ptr = indexed_dataset.get_doc_idx()
            # Slice the doc-idx.
            start_index = splits[index]
            # Add +1 so we can index into the dataset to get the upper bound.
            end_index = splits[index + 1] + 1
            # New doc_idx view.
            indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
            # Build the dataset accordingly.
            dataset = BertDataset(
                name=name,
                indexed_dataset=indexed_dataset,
                tokenizer=tokenizer,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                masked_lm_prob=masked_lm_prob,
                max_seq_length=max_seq_length,
                short_seq_prob=short_seq_prob,
                seed=seed)
            # Restore the original pointer so the full dataset stays intact.
            indexed_dataset.set_doc_idx(doc_idx_ptr)
            # Checks.
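            # After restoring doc_idx, the first entry must still be 0 and the
            # array must again cover all documents (num-docs + 1 entries).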
            assert indexed_dataset.doc_idx[0] == 0
            assert indexed_dataset.doc_idx.shape[0] == \
                (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, 'train')
    valid_dataset = build_dataset(1, 'valid')
    test_dataset = build_dataset(2, 'test')

    return (train_dataset, valid_dataset, test_dataset)


class BertDataset(Dataset):

    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Tokenizer and dataset.
        self.tokenizer = tokenizer
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.vocab['[CLS]']
        self.sep_id = self.tokenizer.vocab['[SEP]']
        self.mask_id = self.tokenizer.vocab['[MASK]']
        self.pad_id = self.tokenizer.vocab['[PAD]']

    def num_tokens(self):
        return self.tokenizer.vocab_size()

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_index, end_index, seq_length = self.samples_mapping[idx]
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,  # needed for padding
                                     self.vocab_id_list,
                                     self.vocab_id_to_token_dict,
                                     self.cls_id, self.sep_id,
                                     self.mask_id, self.pad_id,
                                     self.masked_lm_prob, np_rng)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):

    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0('    number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0('    number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """Get dataset splits from comma or '/' separated string list."""

    splits = []
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    while len(splits) < 3:
        splits.append(0.)
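    # Keep at most three weights (train/valid/test) and normalize them so
    # they sum to one before converting to document index boundaries.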
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
                         seed,
                         name):

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping.
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the index mapping if it does not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))
        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32
        # Build the samples mapping.
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(
            name))
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # Make sure all the ranks have built the mapping before loading it.
    # This should be a barrier, but the NCCL barrier assumes
    # device_index == rank, which is not the case with model parallelism,
    # so use an all-reduce over the data-parallel group instead.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load the index mapping.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping
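

# Illustrative usage sketch. The values below are hypothetical; in Megatron
# they normally come from the parsed command-line arguments, and the call
# requires torch.distributed to be initialized because the samples-mapping
# build synchronizes ranks:
#
#     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#         vocab_file='bert-vocab.txt',
#         data_prefix='my-corpus_text_sentence',
#         data_impl='mmap',
#         splits_string='949,50,1',
#         train_valid_test_num_samples=[1000000, 10000, 1000],
#         max_seq_length=512,
#         masked_lm_prob=0.15,
#         short_seq_prob=0.1,
#         seed=1234,
#         skip_warmup=True)
#
# Each returned object is a torch.utils.data.Dataset and can be wrapped in a
# DataLoader for training.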