# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on the Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
#    An empty sentence no longer separates documents.

from functools import lru_cache
import os
import stat
import shutil
import struct
from itertools import accumulate

import numpy as np
import torch

from megatron import print_rank_0


def best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def get_available_dataset_impl():
    return ['lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None


def make_builder(out_file, impl, dtype=None):
    if impl == 'mmap':
        assert dtype is not None
        return MMapIndexedDatasetBuilder(out_file, dtype=dtype)
    else:
        assert dtype is None
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)
    print(f"Unknown dataset implementation: {impl}")
    return None


def dataset_exists(path, impl):
    if impl == 'mmap':
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.double,
    8: np.uint16
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx
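
# Illustrative sketch only (not part of the original module): shows how the
# helpers above are typically combined to write a dataset and read it back.
# The function name, the prefix path, the token ids, and the vocab size are
# hypothetical examples; the builder and dataset classes it uses are defined
# later in this file and are resolved at call time.
def _example_build_and_read(prefix='/tmp/example_dataset', vocab_size=50257):
    # Write two documents of token ids to prefix.bin / prefix.idx using the
    # mmap implementation, choosing the smallest dtype that can hold vocab_size.
    builder = make_builder(data_file_path(prefix), impl='mmap',
                           dtype=best_fitting_dtype(vocab_size))
    for doc in [[10, 11, 12], [13, 14]]:
        builder.add_item(torch.tensor(doc, dtype=torch.int64))
        builder.end_document()
    builder.finalize(index_file_path(prefix))

    # Read the dataset back; each element is a numpy array of token ids.
    ds = make_dataset(prefix, impl='mmap', skip_warmup=True)
    return [ds[i] for i in range(len(ds))]
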

class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                'Index file doesn\'t match expected format. '
                'Make sure that --dataset-impl is configured properly.'
            )
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))[0]
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx: ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx: ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work, can optimize later if necessary
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.uint16: 2,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float32: 4,
        np.double: 8
    }

    @staticmethod
    def write_header(fout, dtype, numdata, numsize, numdoc):
        """Writes header for cached indexed dataset to given file handle, return number of bytes written."""
        startpos = fout.tell()

        fout.write(IndexedDataset._HDR_MAGIC)
        fout.write(struct.pack('<Q', 1))
        fout.write(struct.pack('<QQ', code(dtype), dtype().itemsize))
        fout.write(struct.pack('<QQ', numdata - 1, numsize))
        fout.write(struct.pack('<Q', numdoc))

        endpos = fout.tell()
        return endpos - startpos

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        nbytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + nbytes // self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        doc_offset = len(self.sizes)

        begin = self.data_offsets[-1]
        for data_offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + data_offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)
        self.doc_idx.extend((doc_offset + index.doc_idx)[1:])

        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        self.write_header(index, self.dtype, len(self.data_offsets),
                          len(self.sizes), len(self.doc_idx))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
        while stream.read(100 * 1024 * 1024):
            pass


def exscan_from_cumsum_(arr):
    # given an array holding the result of an inclusive scan (cumsum),
    # convert it in place to an exclusive scan (shift to the right)
    # [10, 30, 35, 50] --> [0, 10, 30, 35]
    if arr.size > 1:
        arr[1:] = arr[:-1]
    if arr.size > 0:
        arr[0] = 0
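
# Small illustrative sketch (not used by the module): demonstrates the
# exclusive-scan convention shared by exscan_from_cumsum_ above and
# get_pointers_with_total below, assuming 2-byte (np.uint16) tokens.
# The function name and the example sizes are hypothetical.
def _example_exclusive_scan():
    sizes = [10, 20, 5, 15]                          # tokens per sentence
    offsets = np.array(sizes, dtype=np.int64) * 2    # byte counts for uint16 tokens
    np.cumsum(offsets, out=offsets)                  # inclusive scan: [20, 60, 70, 100]
    exscan_from_cumsum_(offsets)                     # exclusive scan: [0, 20, 60, 70]
    return offsets                                   # byte offset where each sentence starts
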
""" # scale values in sizes array by elemsize to get sizes in bytes pointers = np.array(sizes, dtype=dtype) pointers *= elemsize np.cumsum(pointers, axis=0, out=pointers) # get total number of bytes from all sizes (last element) bytes_last = pointers[-1] if len(sizes) > 0 else 0 # convert to byte offsets exscan_from_cumsum_(pointers) return pointers, bytes_last class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @staticmethod def write_header(fout, dtype, numsizes, numdocs): """Writes header for mmap indexed dataset to given file handle, return number of bytes written.""" startpos = fout.tell() fout.write(MMapIndexedDataset.Index._HDR_MAGIC) fout.write(struct.pack(' 0, "All ranks have no input files to merge" # Check that files are all of the same index type indexstr = gather_files_dist_check_impltype(filelist, distctx) # Concatenate the data files gather_files_dist_bin(filemain, filelist, distctx) # Combine index files into a single index file if indexstr == "cached": gather_files_dist_idx_cached(filemain, filelist, distctx) elif indexstr == "mmap": gather_files_dist_idx_mmap(filemain, filelist, distctx) def get_start_end(count, rank, numranks): """Return (start, end) index values for calling rank to evenly divide count items among numranks. Example usage: start, end = get_start_end(len(itemlist), distctx.rank, distctx.numranks) sublist = itemlist[start:end] Parameters ---------- count : int Total number of items to be divided rank : int Rank of the calling process, within range of [0, numranks) numranks : int Number of ranks by which to divide count items Returns ---------- (start, end) : tuple(int) Start and end index values that define the [start, end) range for rank """ num, remainder = divmod(count, numranks) if rank < remainder: start = (num + 1) * rank end = start + num + 1 else: start = (num + 1) * remainder + num * (rank - remainder) end = start + num return start, end def merge_files_dist(filemain, filelist, distctx): """Merge list of indexed datasets into a single indexed dataset named in filemain. Given a list of indexed datasets in filelist, and the set of processes defined by the distributed environment in distctx, collectively merge files into a new, single output indexed dataset named in filemain. This overwrites filemain if it already exists. It does not delete the input datasets in filelist. The input parameters filemain and filelist must be identical on all calling processes, and all processes in distctx must call this method collectively. It requires that all ranks be able to read any file in filelist, and all ranks must be able to write to the single output file named in filemain.""" # TODO: if file sizes vary significantly, it might be better to consider # file size when splitting the list to different ranks. # evenly divide list of files among ranks start, end = get_start_end(len(filelist), distctx.rank, distctx.numranks) sublist = filelist[start:end] # delegate merge to gather implementation return gather_files_dist(filemain, sublist, distctx)