# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT Style dataset."""

import os
import time

import numpy as np
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron import mpu
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron import print_rank_0

DATASET_TYPES = ['standard_bert', 'ict', 'realm']


class BertDataset(Dataset):
    """Map-style dataset that turns an indexed dataset into training samples.

    Each item is produced by looking up one row of the samples mapping,
    gathering the corresponding entries of the indexed dataset and handing
    them to ``build_training_sample`` together with the tokenizer's special
    token ids.
    """

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
        """Store the configuration and build (or load) the samples mapping.

        Args:
            name: Dataset name; becomes part of the index-map file name.
            indexed_dataset: Indexable dataset the samples are drawn from.
            data_prefix: Path prefix used for the index-map file name.
            num_epochs: Epoch limit forwarded to the mapping builder
                (may be falsy if ``max_num_samples`` is given).
            max_num_samples: Sample limit forwarded to the mapping builder
                (may be falsy if ``num_epochs`` is given).
            masked_lm_prob: Forwarded to the sample builder on every item.
            max_seq_length: Maximum sequence length; the sample builder
                receives it for padding.
            short_seq_prob: Forwarded to the mapping builder.
            seed: Base seed; combined with the item index in ``__getitem__``.
        """
        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        self.build_sample_fn = build_training_sample

    def __len__(self):
        """Return the number of rows in the samples mapping."""
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Build and return the training sample for mapping row ``idx``."""
        # Each mapping row is (start index, end index (exclusive), target
        # sequence length) into the indexed dataset.
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        # numpy only accepts seeds in [0, 2**32 - 1], so wrap the sum;
        # otherwise a large seed or sample index raises ValueError.
        np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
        return self.build_sample_fn(sample, seq_length,
                                    self.max_seq_length,  # needed for padding
                                    self.vocab_id_list,
                                    self.vocab_id_to_token_dict,
                                    self.cls_id, self.sep_id,
                                    self.mask_id, self.pad_id,
                                    self.masked_lm_prob, np_rng)


def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
    """Create the indexed dataset and log basic statistics about it.

    Args:
        data_prefix: Passed through to ``make_indexed_dataset``.
        data_impl: Passed through to ``make_indexed_dataset``.
        skip_warmup: Passed through to ``make_indexed_dataset``.

    Returns:
        The object returned by ``make_indexed_dataset``.
    """
    print_rank_0(' > building dataset index ...')

    start_time = time.time()
    indexed_dataset = make_indexed_dataset(data_prefix,
                                           data_impl,
                                           skip_warmup)
    # The last document offset must equal the number of size entries.
    assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
    print_rank_0(' > finished creating indexed dataset in {:4f} '
                 'seconds'.format(time.time() - start_time))

    print_rank_0(' > indexed dataset stats:')
    print_rank_0(' number of documents: {}'.format(
        indexed_dataset.doc_idx.shape[0] - 1))
    print_rank_0(' number of sentences: {}'.format(
        indexed_dataset.sizes.shape[0]))

    return indexed_dataset


def get_train_valid_test_split_(splits_string, size):
    """ Get dataset splits from comma or '/' separated string list.

    The string holds up to three weights (missing ones count as 0, extra
    ones are ignored).  The weights are normalized to sum to one and turned
    into cumulative boundaries over ``size`` items.

    Args:
        splits_string: Weights such as ``'949,50,1'`` or ``'90/5/5'``, or a
            single number.
        size: Total number of items to split.

    Returns:
        A list of four integers ``[0, b1, b2, size]``.
    """
    if splits_string.find(',') != -1:
        splits = [float(s) for s in splits_string.split(',')]
    elif splits_string.find('/') != -1:
        splits = [float(s) for s in splits_string.split('/')]
    else:
        splits = [float(splits_string)]
    # Always work with exactly three weights.
    while len(splits) < 3:
        splits.append(0.)
    splits = splits[:3]
    splits_sum = sum(splits)
    assert splits_sum > 0.0
    splits = [split / splits_sum for split in splits]
    splits_index = [0]
    for index, split in enumerate(splits):
        splits_index.append(splits_index[index] +
                            int(round(split * float(size))))
    # Rounding can make the last boundary miss `size`; shift every boundary
    # after the first by the same amount so the last one lands on `size`.
    diff = splits_index[-1] - size
    for index in range(1, len(splits_index)):
        splits_index[index] -= diff
    assert len(splits_index) == 4
    assert splits_index[-1] == size
    return splits_index


def get_samples_mapping_(indexed_dataset,
                         data_prefix,
                         num_epochs,
                         max_num_samples,
                         max_seq_length,
                         short_seq_prob,
                         seed,
                         name):
    """Load the samples mapping, building and caching it first if needed.

    The mapping is cached in a ``.npy`` file whose name is derived from
    ``data_prefix`` and the other arguments.  If the file does not exist,
    global rank 0 builds it with the compiled ``helpers.build_mapping`` and
    saves it; afterwards every rank loads the file.

    Args:
        indexed_dataset: Dataset exposing ``doc_idx`` (int64) and ``sizes``
            (int32) arrays.
        data_prefix: Path prefix of the cache file.
        num_epochs: Epoch limit; if falsy, ``max_num_samples`` must be set.
        max_num_samples: Sample limit; may be falsy when ``num_epochs`` is set.
        max_seq_length: Maximum sequence length; the builder receives this
            value minus 3 to leave room for the added tokens.
        short_seq_prob: Forwarded to the builder.
        seed: Forwarded to the builder.
        name: Dataset name, used in the cache file name and in log output.

    Returns:
        The array loaded from the cache file.

    Raises:
        ValueError: If neither ``num_epochs`` nor ``max_num_samples`` is given.
    """
    # Replace an unspecified limit with a sentinel meaning "no limit".
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping (sentinel limits are left out of it).
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{:0.2f}ssp'.format(short_seq_prob)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert indexed_dataset.doc_idx.dtype == np.int64
        assert indexed_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building sapmles index mapping for {} ...'.format(
            name))
        # First compile and then import.
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        from megatron.data import helpers
        samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            seed,
            verbose)
        print_rank_0(' > done building sapmles index maping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(
            indexmap_filename))
        print_rank_0(' > elasped time to build and save samples mapping '
                     '(seconds): {:4f}'.format(
                         time.time() - start_time))

    # Make sure all the ranks have built the mapping.
    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    # NOTE(review): the reduction only spans the caller's data-parallel
    # group, while the file is written by global rank 0 -- confirm ranks in
    # other data-parallel groups cannot reach np.load before the file exists.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(
        indexmap_filename))
    start_time = time.time()
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
    # only point data_prefix at locations whose contents are trusted.
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        samples_mapping.shape[0]))

    return samples_mapping