# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # Essentially re-written in entirety import logging import os import shutil import struct import time from enum import Enum from functools import lru_cache from itertools import accumulate from types import TracebackType from typing import List, Optional, Tuple, Type, Union import numpy import torch from megatron.core.utils import log_single_rank logger = logging.getLogger(__name__) _INDEX_HEADER = b"MMIDIDX\x00\x00" class DType(Enum): """The NumPy data type Enum for writing/reading the IndexedDataset indices """ uint8 = 1 int8 = 2 int16 = 3 int32 = 4 int64 = 5 float64 = 6 float32 = 7 uint16 = 8 @classmethod def code_from_dtype(cls, value: Type[numpy.number]) -> int: """Get the code from the dtype Args: value (Type[numpy.number]): The dtype Returns: int: The code """ return cls[value.__name__].value @classmethod def dtype_from_code(cls, value: int) -> Type[numpy.number]: """Get the dtype from the code Args: value (int): The code Returns: Type[numpy.number]: The dtype """ return getattr(numpy, cls(value).name) @staticmethod def size(key: Union[int, Type[numpy.number]]) -> int: """Get the size of the dtype/code in bytes Args: key (Union[int, Type[numpy.number]]): The dtype or code Raises: ValueError: If the key is neither dtype nor integer code Returns: int: The size of the dtype/code in in bytes """ if isinstance(key, int): return DType.dtype_from_code(key)().itemsize elif numpy.number in key.__mro__: return key().itemsize else: raise ValueError @staticmethod def optimal_dtype(cardinality: Optional[int]) -> Type[numpy.number]: """Get the dtype to use for an index of a certain cardinality Args: cardinality (Optional[int]): The number of elements to be indexed Returns: Type[numpy.number]: The dtype to use for the index """ if cardinality is not None and cardinality < 65500: return numpy.uint16 else: return numpy.int32 class _IndexWriter(object): """Object class to write the index (.idx) file Args: idx_path (str): The path to the index file dtype (Type[numpy.number]): The dtype of the index file """ def __init__(self, idx_path: str, dtype: Type[numpy.number]) -> None: self.idx_path = idx_path self.dtype = dtype def __enter__(self) -> "_IndexWriter": """Enter the context introduced by the 'with' keyword Returns: _IndexWriter: The instance """ self.idx_writer = open(self.idx_path, "wb") # fixed, vestigial practice self.idx_writer.write(_INDEX_HEADER) # fixed, vestigial practice self.idx_writer.write(struct.pack(" Optional[bool]: """Exit the context introduced by the 'with' keyword Args: exc_type (Optional[Type[BaseException]]): Exception type exc_val (Optional[BaseException]): Exception value exc_tb (Optional[TracebackType]): Exception traceback object Returns: Optional[bool]: Whether to silence the exception """ self.idx_writer.close() def write( self, sequence_lengths: List[int], sequence_modes: Optional[List[int]], document_indices: List[int], ) -> None: """Write the index (.idx) file Args: sequence_lengths (List[int]): The length of each sequence sequence_modes (Optional[List[int]]): The mode of each sequences document_indices (List[int]): The seqyebce indices demarcating the end of each document """ sequence_pointers = self._sequence_pointers(sequence_lengths) # the number of sequences in the dataset sequence_count = len(sequence_lengths) self.idx_writer.write(struct.pack(" List[int]: """Build the sequence pointers per the sequence lengths and dtype size Args: sequence_lengths (List[int]): The length of each sequence Returns: List[int]: The pointer to the beginning of each sequence """ itemsize = DType.size(self.dtype) curr_ptr = 0 list_ptr = [] for length in sequence_lengths: list_ptr.append(curr_ptr) curr_ptr += length * itemsize return list_ptr class _IndexReader(object): """Object class to read the index (.idx) file Args: idx_path (str): The path to the index file multimodal (bool): Whether the dataset is multimodal """ def __init__(self, idx_path: str, multimodal: bool) -> None: log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}") with open(idx_path, "rb") as stream: header = stream.read(9) assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}" version = struct.unpack(" time elapsed: {t_end - t_beg:4f} seconds") log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers") t_beg = time.time() self.sequence_pointers = numpy.frombuffer( self.bin_buffer, dtype=numpy.int64, count=self.sequence_count, offset=offset + self.sequence_lengths.nbytes, ) t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") log_single_rank(logger, logging.INFO, f"\tExtract the document indices") t_beg = time.time() self.document_indices = numpy.frombuffer( self.bin_buffer, dtype=numpy.int64, count=self.document_count, offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes, ) t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") self.sequence_modes = None if multimodal: log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes") t_beg = time.time() self.sequence_modes = numpy.frombuffer( self.bin_buffer, dtype=numpy.int8, count=self.sequence_count, offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes + self.document_indices.nbytes, ) t_end = time.time() log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds") assert self.sequence_lengths.shape[0] == len(self) assert self.sequence_lengths.shape[0] == self.sequence_count assert self.sequence_lengths.shape[0] == self.document_indices[-1] log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}") log_single_rank( logger, logging.INFO, f"> total number of documents: {self.document_indices.shape[0] - 1}", ) def __del__(self) -> None: """Clean up the object """ if hasattr(self, "bin_buffer_mmap"): self.bin_buffer_mmap._mmap.close() del self.bin_buffer_mmap def __len__(self) -> int: """Return the length of the dataset Returns: int: The length of the dataset """ return self.sequence_count @lru_cache(maxsize=8) def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: """Return the pointer, length, and mode at the index Args: idx (int): The index into the dataset Returns: Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at the index """ return ( self.sequence_pointers[idx], self.sequence_lengths[idx], self.sequence_modes[idx] if self.sequence_modes is not None else None, ) class IndexedDataset(torch.utils.data.Dataset): """The low-level interface dataset class Args: path_prefix (str): The index (.idx) and data (.bin) prefix multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. mmap (bool, optional): Whether to mmap the .bin files. Defaults to True. """ def __init__(self, path_prefix: str, multimodal: bool = False, mmap: bool = True) -> None: super().__init__() self.path_prefix = None self.multimodal = None self.mmap = None self.initialize(path_prefix, multimodal, mmap) def initialize(self, path_prefix: str, multimodal: bool, mmap: bool) -> None: """Initialize the dataset This method is called by IndexedDataset.__init__ during object creation and by IndexedDataset.__setstate__ during un-pickling Args: path_prefix (str): The index (.idx) and data (.bin) prefix multimodal (bool): Whether the dataset is multimodal mmap (bool): Whether to mmap the .bin file """ idx_path = get_idx_path(path_prefix) bin_path = get_bin_path(path_prefix) assert os.path.exists(idx_path) and os.path.exists( bin_path ), f"One or both of the .idx and .bin files cannot be found at the path prefix {path_prefix}" self.path_prefix = path_prefix self.multimodal = multimodal self.mmap = mmap self.index = _IndexReader(idx_path, self.multimodal) self.bin_buffer = None self.bin_buffer_mmap = None if mmap: self.bin_buffer_mmap = numpy.memmap(bin_path, mode="r", order="C") self.bin_buffer = memoryview(self.bin_buffer_mmap) def __getstate__(self) -> Tuple[str, bool, bool]: """Get the state during pickling Returns: Tuple[str, bool, bool]: The state tuple """ return self.path_prefix, self.multimodal, self.mmap def __setstate__(self, state: Tuple[str, bool, bool]) -> None: """Set the state during un-pickling Args: state (Tuple[str, bool, bool]): The state tuple """ path_prefix, multimodal, mmap = state self.initialize(path_prefix, multimodal, mmap) def __del__(self) -> None: """Clean up the object """ if self.bin_buffer_mmap is not None: self.bin_buffer_mmap._mmap.close() del self.bin_buffer_mmap del self.index def __len__(self) -> int: """Return the length of the dataset i.e. the number of sequences in the index Returns: int: The length of the dataset """ return len(self.index) def _getitem_mmap( self, idx: Union[int, numpy.integer, slice] ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: """Return from the dataset by mmap-ing .bin file Args: idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset Raises: ValueError: When the index slice is non-contiguous TypeError: When the index is of an unexpected type Returns: Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index or index slice """ if isinstance(idx, (int, numpy.integer)): sequence_pointer, sequence_length, sequence_mode = self.index[idx] sequence = numpy.frombuffer( self.bin_buffer, dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer, ) return (sequence, sequence_mode) if sequence_mode is not None else sequence elif isinstance(idx, slice): start, stop, step = idx.indices(len(self)) if step != 1: raise ValueError("Slices into indexed_dataset must be contiguous") sequence_lengths = self.index.sequence_lengths[idx] sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None sequence_offsets = list(accumulate(sequence_lengths)) sequences = numpy.split( numpy.frombuffer( self.bin_buffer, dtype=self.index.dtype, count=sum(sequence_lengths), offset=self.index.sequence_pointers[start], ), sequence_offsets[:-1], ) return (sequences, sequence_modes) if sequence_modes is not None else sequences else: raise TypeError("Unexpected type received for idx: {}".format(type(idx))) def _getitem_file( self, idx: Union[int, numpy.integer, slice] ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: """Return from the dataset by using file pointer Args: idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset Raises: ValueError: When the index slice is non-contiguous TypeError: When the index is of an unexpected type Returns: Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index or index slice """ if isinstance(idx, (int, numpy.integer)): sequence_pointer, sequence_length, sequence_mode = self.index[idx] sequence = numpy.empty(sequence_length, dtype=self.index.dtype) with open(get_bin_path(self.path_prefix), mode='rb', buffering=0) as bin_buffer_file: bin_buffer_file.seek(sequence_pointer) bin_buffer_file.readinto(sequence) return (sequence, sequence_mode) if sequence_mode is not None else sequence elif isinstance(idx, slice): assert False, "slicing not implemented without mmap" else: raise TypeError("Unexpected type received for idx: {}".format(type(idx))) def __getitem__( self, idx: Union[int, numpy.integer, slice] ) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: """Return from the dataset Args: idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset Raises: ValueError: When the index slice is non-contiguous TypeError: When the index is of an unexpected type Returns: Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index or index slice """ if self.bin_buffer_mmap is not None: return self._getitem_mmap(idx) else: return self._getitem_file(idx) def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray: """Retrieve a single item from the dataset with the option to only return a portion of the item. get(idx) is the same as [idx] but get() does not support slicing. Args: idx (Union[int, numpy.integer]): The index into the dataset offset (int): The integer token offset in the sequence length (int): The number of tokens to grab from the sequence Returns: Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index """ sequence_pointer, sequence_length, sequence_mode = self.index[idx] if length is None: length = sequence_length - offset sequence_pointer += offset * DType.size(self.index.dtype) if self.bin_buffer: sequence = numpy.frombuffer( self.bin_buffer, dtype=self.index.dtype, count=length, offset=sequence_pointer ) else: sequence = numpy.empty(length, dtype=self.index.dtype) with open(get_bin_path(self.path_prefix), mode='rb', buffering=0) as bin_buffer_file: bin_buffer_file.seek(sequence_pointer) bin_buffer_file.readinto(sequence) return (sequence, sequence_mode) if sequence_mode is not None else sequence @property def sequence_lengths(self) -> numpy.ndarray: """Get the sequence lengths Returns: numpy.ndarray: The sequence lengths """ return self.index.sequence_lengths @property def document_indices(self) -> numpy.ndarray: """Get the document indices Returns: numpy.ndarray: The document indices """ return self.index.document_indices def get_document_indices(self) -> numpy.ndarray: """Get the document indices This method is slated for deprecation. Returns: numpy.ndarray: The document indices """ return self.index.document_indices def set_document_indices(self, document_indices: numpy.ndarray) -> None: """Set the document indices This method is slated for deprecation. Args: document_indices (numpy.ndarray): The document indices """ self.index.document_indices = document_indices @property def sequence_modes(self) -> numpy.ndarray: """Get the sequence modes Returns: numpy.ndarray: The sequence modes """ return self.index.sequence_modes @staticmethod def exists(path_prefix: str) -> bool: """Return whether the IndexedDataset exists on disk at the prefix Args: path_prefix (str): The prefix to the index (.idx) and data (.bin) files Returns: bool: Whether the IndexedDataset exists on disk at the prefix """ return os.path.exists(get_idx_path(path_prefix)) and os.path.exists( get_bin_path(path_prefix) ) class IndexedDatasetBuilder(object): """Builder class for the IndexedDataset class Args: bin_path (str): The path to the data (.bin) file dtype (Type[numpy.number], optional): The dtype of the index file. Defaults to numpy.int32. multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False. """ def __init__( self, bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False ) -> None: self.data_file = open(bin_path, "wb") self.dtype = dtype self.multimodal = multimodal self.sequence_lengths = [] self.document_indices = [0] self.sequence_modes = [] if self.multimodal else None def add_item(self, tensor: torch.Tensor, mode: int = 0) -> None: """Add a single item to the dataset Args: tensor (torch.Tensor): The item to add to the data file mode (int, optional): The mode for the item. Defaults to 0. """ np_array = numpy.array(tensor.numpy(), dtype=self.dtype) self.data_file.write(np_array.tobytes(order="C")) self.sequence_lengths.append(np_array.size) if self.multimodal: self.sequence_modes.append(mode) def add_document( self, tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None ) -> None: """Add an entire document to the dataset Args: tensor (torch.Tensor): The document to add lengths (List[int]): The lengths of each item in the document modes (Optional[List[int]], optional): The modes for each item in the document. Defaults to None. """ np_array = numpy.array(tensor, dtype=self.dtype) self.data_file.write(np_array.tobytes(order="C")) self.sequence_lengths.extend(lengths) self.document_indices.append(len(self.sequence_lengths)) if self.multimodal: self.sequence_modes.extend(modes if modes is not None else [0] * lengths) def end_document(self) -> None: """Finalize the document, for use with IndexedDatasetBuilder.add_item """ self.document_indices.append(len(self.sequence_lengths)) def add_index(self, path_prefix: str) -> None: """Add an entire IndexedDataset to the dataset Args: path_prefix (str): The index (.idx) and data (.bin) prefix """ # Concatenate index index = _IndexReader(get_idx_path(path_prefix), multimodal=self.multimodal) assert index.dtype == self.dtype offset = len(self.sequence_lengths) self.sequence_lengths.extend(index.sequence_lengths) self.document_indices.extend((offset + index.document_indices)[1:]) if self.multimodal: self.sequence_modes.extend(index.sequence_modes) # Concatenate data with open(get_bin_path(path_prefix), "rb") as f: shutil.copyfileobj(f, self.data_file) def finalize(self, idx_path: str) -> None: """Clean up and write the index (.idx) file Args: idx_path (str): The path to the index file """ self.data_file.close() with _IndexWriter(idx_path, self.dtype) as writer: writer.write(self.sequence_lengths, self.sequence_modes, self.document_indices) def get_idx_path(path_prefix: str) -> str: """Get the path to the index file from the prefix Args: path_prefix (str): The prefix Returns: str: The path to the index file """ return path_prefix + ".idx" def get_bin_path(path_prefix: str) -> str: """Get the path to the data file from the prefix Args: path_prefix (str): The prefix Returns: str: The path to the data file """ return path_prefix + ".bin"