Commit 4e867b3c authored by jerrrrry

Initial commit
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List, Optional, Type, Union
import numpy
import torch
from megatron.core.datasets.blended_dataset import BlendedDataset
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.megatron_dataset import LowLevelDataset, MegatronDataset
from megatron.core.datasets.utils import Split, normalize
from megatron.core.parallel_state import get_virtual_pipeline_model_parallel_rank
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
MidLevelDataset = MegatronDataset
TopLevelDataset = Union[BlendedDataset, MidLevelDataset]
DistributedDataset = Union[
TopLevelDataset, MidLevelDataset, LowLevelDataset, torch.utils.data.Dataset
]
class BlendedMegatronDatasetBuilder(object):
"""Builder class for the BlendedDataset and MegatronDataset classes
Args:
cls (Type[MegatronDataset]): The class to instantiate, must inherit from MegatronDataset
sizes (List[Optional[int]]): The minimum total number of samples to draw, or None, per split
is_built_on_rank (Callable): A callable which returns True if the dataset should be built on
the current rank and False otherwise. It should be Megatron Core parallelism aware i.e.
global rank, local group rank, and virtual rank may inform its return value.
config (BlendedMegatronDatasetConfig): The config object which informs dataset creation
"""
def __init__(
self,
cls: Type[MidLevelDataset],
sizes: List[int],
is_built_on_rank: Callable,
config: BlendedMegatronDatasetConfig,
):
self.cls = cls
self.sizes = sizes
self.is_built_on_rank = is_built_on_rank
self.config = config
log_single_rank(
logger,
logging.INFO,
f"Building {cls.__name__} splits with sizes={self.sizes} and config={self.config}",
)
if not self.config.mock:
for split in Split:
size_is_none = self.sizes[split.value] is None
if self.config.blend_per_split is None:
weights_are_none = self.config.blend[1] is None
else:
if self.config.blend_per_split[split.value] is None:
continue
weights_are_none = self.config.blend_per_split[split.value][1] is None
if size_is_none:
assert (
weights_are_none
), f"size_is_none => weights_are_none fails for {split.name} split"
if torch.distributed.is_initialized():
gb_rank = torch.distributed.get_rank()
vp_rank = get_virtual_pipeline_model_parallel_rank()
if gb_rank == 0 and (vp_rank == 0 or vp_rank is None):
assert (
self.is_built_on_rank()
), "is_built_on_rank must return True when global rank = 0 and vp rank = 0"
def build(self) -> List[Optional[TopLevelDataset]]:
"""Build all dataset splits according to the provided blend(s)
This method is distributed-aware and must be called on all ranks.
The dataset splits returned can vary according to the config. Supply config.blend and
config.split to build BlendedDataset and/or MegatronDataset splits from the same
distribution. Supply config.blend_per_split to build BlendedDataset and/or MegatronDataset
splits from separate distributions. In either case, for each split, handle the following
cases:
(1) The split is None
- do nothing
(2) The split has one contributing dataset, and...
(a) 'size' is not None
- Build a mid-level dataset with low-level dataset sampling in proportion to the
size
(b) 'size' is None
- Build mid-level datasets with no excess low-level dataset sampling
(3) The split has multiple contributing datasets, and...
(a) 'weights' is not None and 'size' is not None
- Build mid-level datasets with low-level dataset sampling in proportion to their
weights and the size
- Build a top-level dataset of length marginally greater than 'size' with mid-level
dataset sampling in proportion to their weights and the size
(b) 'weights' is not None and 'size' is None
- Error
(c) 'weights' is None and 'size' is not None
- Build mid-level datasets with no excess low-level dataset sampling
- Build a top-level dataset of length 'size' (capped at the sum of the mid-level
dataset lengths) with mid-level dataset sampling in proportion to their lengths
and the size
(d) 'weights' is None and 'size' is None
- Build mid-level datasets with no excess low-level dataset sampling
- Build a top-level dataset with no excess mid-level dataset sampling
Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
split
"""
datasets = self._build_blended_dataset_splits()
for dataset in datasets:
if dataset is not None and len(dataset) > 0:
if isinstance(dataset, BlendedDataset):
if dataset.built_anew_on_cache_miss or any(
x.built_anew_on_cache_miss for x in dataset.datasets
):
log_single_rank(
logger,
logging.INFO,
(
f"Verifying NumPy indices for {type(dataset).__name__} "
f"{dataset.split.name} split"
),
)
else:
log_single_rank(
logger,
logging.INFO,
(
f"NumPy indices for {type(dataset).__name__} {dataset.split.name} "
f"split are fully cached, skipping verification"
),
)
continue
# Check blend size
assert dataset.size is None or dataset.size == dataset.dataset_index.shape[0]
# Check blend access of mid-level datasets
dataset_indices, dataset_sizes = numpy.unique(
dataset.dataset_index, return_counts=True
)
for i, (index, size) in enumerate(zip(dataset_indices, dataset_sizes)):
if len(dataset.datasets[index]) < size:
raise IndexError(
f"The {dataset.split.name} blend oversamples the contributing "
f"datasets and, e.g., requests {size} samples from "
f"{type(dataset.datasets[index]).__name__} {i} with size "
f"{len(dataset.datasets[index])}. This is unexpected. "
f"Please file an issue."
)
return datasets
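# Illustrative usage only, not part of this module: a caller supplies a MegatronDataset subclass,
# per-split sample counts, a rank predicate, and a config, then calls build(). The GPTDatasetConfig
# values below are hypothetical placeholders, and `tokenizer` is assumed to be an already
# constructed MegatronTokenizer.
#
#   config = GPTDatasetConfig(
#       random_seed=1234,
#       sequence_length=2048,
#       blend=(["path/to/dataset-a", "path/to/dataset-b"], [0.3, 0.7]),
#       split="99,1,0",
#       tokenizer=tokenizer,
#       reset_position_ids=False,
#       reset_attention_mask=False,
#       eod_mask_loss=False,
#   )
#   train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
#       GPTDataset, [10000, 1000, 0], lambda: True, config
#   ).build()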
def _build_blended_dataset_splits(self) -> List[Optional[TopLevelDataset]]:
"""Build all dataset splits according to the provided blend(s)
See the BlendedMegatronDatasetBuilder.build alias for more information.
Returns:
List[Optional[TopLevelDataset]]: A list containing a dataset instance (or None) per
split
"""
##
# Return fake "mock" datasets
##
if self.config.mock:
split = self.config.split_matrix
try:
return self._build_megatron_dataset_splits(None, split, self.sizes)
except Exception as error:
raise Exception(
f"{self.cls.__name__} failed to build as a mock data generator"
) from error
##
# All splits come from the same distribution
##
elif self.config.blend:
prefixes, weights = self.config.blend
if weights is not None:
weights = normalize(weights)
split = self.config.split_matrix
# Blend consists of a single prefix
if len(prefixes) == 1 and weights is None:
return self._build_megatron_dataset_splits(prefixes[0], split, self.sizes)
# Build the mid-level datasets
if weights is None:
# Build only one "epoch"
sizes_per_dataset_buffer = [[None for split in Split] for prefix in prefixes]
else:
# The number of samples we plan to use per dataset
sizes_per_dataset_target = _get_size_per_split_per_dataset(weights, self.sizes)
# The number of samples we plan to build per dataset
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, self.sizes, margin=0.5
)
# Build each dataset in parallel
megatron_datasets = self._build_megatron_datasets_parallel(
prefixes, split, sizes_per_dataset_buffer
)
# Build the top-level datasets
blended_datasets = [None] * len(Split)
for i in range(len(Split)):
if split[i] is not None:
weights_i = weights
if weights_i is not None and self.sizes[i] is not None:
# Blend according to client-specified weights and client-specified size
size_per_dataset = list(zip(*sizes_per_dataset_target))[i]
size_i = sum(size_per_dataset)
elif weights_i is None:
# Blend according to dataset sizes as-is and (maybe) client-specified size
try:
weights_i = [
len(megatron_dataset) for megatron_dataset in megatron_datasets[i]
]
except TypeError:
weights_i = [0 for _ in prefixes]
if self.sizes[i] is not None:
size_i = min(self.sizes[i], sum(weights_i))
else:
# Build exhaustive indices
size_i = None
else:
raise ValueError(
"Using client-specified weights requires client-specified size"
)
blended_datasets[i] = self.build_generic_dataset(
BlendedDataset,
self.is_built_on_rank,
True, # synchronize_ranks, default behavior to build on rank-0 first
megatron_datasets[i],
weights_i,
size_i,
self.config,
)
return blended_datasets
##
# Each split comes from a separate distribution
##
else:
blended_datasets = [None] * len(Split)
for i in range(len(Split)):
split_spoof = [None] * len(Split)
split_spoof[i] = (0.0, 1.0)
sizes_spoof = [0] * len(Split)
sizes_spoof[i] = self.sizes[i]
# Blend is provided for the split
blend = self.config.blend_per_split[i]
if blend is not None:
prefixes, weights = blend
if weights is not None:
weights = normalize(weights)
# Blend consists of a single prefix
if len(prefixes) == 1:
blended_datasets[i] = self._build_megatron_dataset_splits(
prefixes[0], split_spoof, sizes_spoof
)[i]
continue
# Build mid-level datasets
if weights is None:
sizes_per_dataset_buffer = [
[None for split in Split] for prefix in prefixes
]
else:
# The number of samples we plan to use per dataset
sizes_per_dataset_target = _get_size_per_split_per_dataset(
weights, sizes_spoof
)
# The number of samples we plan to build per dataset
sizes_per_dataset_buffer = _get_size_per_split_per_dataset(
weights, sizes_spoof, margin=0.5
)
# Build each dataset in parallel
megatron_datasets = self._build_megatron_datasets_parallel(
prefixes, split_spoof, sizes_per_dataset_buffer
)[i]
# Build top-level dataset
if weights is not None and self.sizes[i] is not None:
# Blend according to client-specified weights and client-specified size
size_per_dataset = list(zip(*sizes_per_dataset_target))[i]
size = sum(size_per_dataset)
elif weights is None:
# Blend according to dataset sizes as-is and (maybe) client-specified size
try:
weights = [
len(megatron_dataset) for megatron_dataset in megatron_datasets
]
except TypeError:
weights = [0 for _ in prefixes]
if self.sizes[i] is not None:
size = min(self.sizes[i], sum(weights))
else:
# Build exhaustive indices
size = None
else:
raise RuntimeError
blended_datasets[i] = self.build_generic_dataset(
BlendedDataset,
self.is_built_on_rank,
True, # synchronize_ranks, default behavior to build on rank-0 first
megatron_datasets,
weights,
size,
self.config,
)
return blended_datasets
def _build_megatron_datasets_parallel(
self, prefixes: List[str], split: List[float], sizes_per_dataset: List[List[int]]
) -> List[List[Optional[MegatronDataset]]]:
"""Build the megatron datasets for a list of prefixes in parallel
Args:
prefixes (List[str]): The list of prefix strings
split (List[float]): The dataset split ratios (must sum to 1.00)
sizes_per_dataset (List[List[int]]): The number of samples to request
per MegatronDataset per split
Returns:
List[List[Optional[MegatronDataset]]]: For each split, have a list of
MegatronDataset per prefix
"""
# Helper function to wrap the threading logic
def _threading_helper(
megatron_datasets: List[List[Optional[MegatronDataset]]],
num_workers: int,
prefixes: List[str],
split: List[float],
sizes_per_dataset: List[List[int]],
) -> None:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
all_futures = []
for i in range(len(prefixes)):
all_futures.append(
executor.submit(
self._build_megatron_dataset_splits,
prefixes[i],
split,
sizes_per_dataset[i],
False, # synchronize_ranks, barrier is called in this function
)
)
for future in all_futures:
try:
megatron_datasets_split = future.result()
for j in range(len(megatron_datasets_split)):
megatron_datasets[j].append(megatron_datasets_split[j])
except Exception as err:
raise err
megatron_datasets = [[] for _ in range(len(Split))]
num_dataset_builder_threads = self.config.num_dataset_builder_threads
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
# First, build on rank 0
if rank == 0:
num_workers = num_dataset_builder_threads
if num_workers > 1:
# since only rank 0 is running, scale up the thread count
# but not too much to avoid overloading storage on miss path.
# if user set num_dataset_builder_threads to 1,
# i.e. meant for serial build, do not scale up.
num_workers *= min(2, max(1, torch.cuda.device_count()))
_threading_helper(
megatron_datasets, num_workers, prefixes, split, sizes_per_dataset
)
torch.distributed.barrier()
# Then, build on other ranks; guaranteed to be data_cache hit
if rank != 0:
_threading_helper(
megatron_datasets,
num_dataset_builder_threads,
prefixes,
split,
sizes_per_dataset,
)
else:
_threading_helper(
megatron_datasets, num_dataset_builder_threads, prefixes, split, sizes_per_dataset
)
return megatron_datasets
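# Worked example of the thread scaling above (illustrative only): with
# config.num_dataset_builder_threads = 2 on a node reporting two or more CUDA devices, rank 0
# builds with 2 * 2 = 4 threads; the other ranks, which are guaranteed cache hits after the
# barrier, stay at 2 threads; a setting of 1 is never scaled up, preserving a strictly serial
# build.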
def _build_megatron_dataset_splits(
self,
dataset_path: Optional[str],
split: List[float],
sizes: List[int],
synchronize_ranks: bool = True,
) -> List[Optional[MidLevelDataset]]:
"""Build each MidLevelDataset split from a single LowLevelDataset
Args:
dataset_path (Optional[str]): The path on disk which defines the underlying
LowLevelDataset, or None for mock dataset classes
split (List[Tuple[float, float]]): The dataset split matrix
sizes (List[int]): The number of total samples to draw from each split
synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
behavior. Set to False when we enforce this behavior at higher level.
Returns:
List[Optional[MidLevelDataset]]: The MidLevelDataset (or None) per split
"""
# short-cut if we are not building on this rank
if torch.distributed.is_initialized() and not self.is_built_on_rank():
for i in range(len(Split)):
if split[i] is not None and synchronize_ranks:
torch.distributed.barrier()
return [None] * len(Split)
# Build the low level dataset
low_level_dataset = self.cls.build_low_level_dataset(dataset_path, self.config)
# Build the split indices for the low level dataset
num_elements = self.cls.numel_low_level_dataset(low_level_dataset)
split_indices = []
for i, _ in enumerate(Split):
if split[i] is not None:
beg = int(round(split[i][0] * float(num_elements)))
end = int(round(split[i][1] * float(num_elements)))
split_indices.append(numpy.arange(start=beg, stop=end, step=1, dtype=numpy.int32))
else:
split_indices.append(None)
# Build the mid level dataset
mid_level_datasets = []
for i, _split in enumerate(Split):
if split[i] is None:
mid_level_datasets.append(None)
else:
mid_level_datasets.append(
self.build_generic_dataset(
self.cls,
self.is_built_on_rank,
synchronize_ranks,
low_level_dataset,
dataset_path,
split_indices[i],
sizes[i],
_split,
self.config,
)
)
return mid_level_datasets
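# Illustrative example with hypothetical numbers: for num_elements = 1000 and a split matrix of
# [(0, 0.99), (0.99, 1.0), None], the code above produces train indices arange(0, 990), valid
# indices arange(990, 1000), and no test split, then builds one MidLevelDataset per non-None
# split.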
@staticmethod
def build_generic_dataset(
cls: Union[Type[DistributedDataset], Callable],
is_built_on_rank: Callable,
synchronize_ranks: bool,
*args: Any,
) -> Optional[Union[DistributedDataset, Iterable]]:
"""Build the DistributedDataset
Return None if and only if the underlying dataset class is not built on the current rank
and torch.distributed is initialized.
Args:
cls (Union[Type[DistributedDataset], Callable]): The DistributedDataset class to be
built. In special cases, e.g. when we are building the low level dataset for a
RawMegatronDataset instance, we can accept a Callable which returns an Iterable.
synchronize_ranks (bool): Whether to call barrier for rank-0 / barrier / other-ranks
behavior. Set to False when we enforce this behavior at higher level.
args (Tuple[Any]): The positional arguments used to build the provided
DistributedDataset class
Raises:
Exception: When the dataset constructor raises an OSError
Returns:
Optional[Union[DistributedDataset, Iterable]]: The DistributedDataset instantiation, the
Iterable instantiation, or None
"""
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
dataset = None
# First, build on rank 0
if rank == 0 and is_built_on_rank():
try:
dataset = cls(*args)
except OSError as err:
log = (
f"Failed to write dataset materials to the data cache directory. Please "
f"supply a directory to which you have write access via the path_to_cache "
f"attribute in BlendedMegatronDatasetConfig and retry. Refer to the "
f"preserved traceback above for more information."
)
raise Exception(log) from err
if synchronize_ranks:
torch.distributed.barrier()
# After, build on other ranks
if rank != 0 and is_built_on_rank():
dataset = cls(*args)
return dataset
return cls(*args)
def _get_size_per_split_per_dataset(
normalized_weights: List[float], target_size_per_split: List[int], margin: float = 0.0
) -> List[List[int]]:
"""Determine the contribution of the MegatronDataset splits to the BlendedDataset splits
Args:
normalized_weights (List[float]): e.g. [0.3, 0.7]
target_size_per_split (List[int]): The number of samples to target for each BlendedDataset
split
margin (float): The relative quantity of extra samples to build per split per dataset,
as a percentage
Returns:
List[List[int]]: The number of samples to request per MegatronDataset per split
"""
assert numpy.isclose(sum(normalized_weights), 1.0)
# Use margin as buffer to ensure we satiate the request
sizes_per_dataset = [
[
int(math.ceil(math.ceil(target_size * weight) * (1 + margin / 100)))
for target_size in target_size_per_split
]
for weight in normalized_weights
]
return sizes_per_dataset
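# Worked example with illustrative values: for normalized_weights = [0.3, 0.7],
# target_size_per_split = [1000, 100, 0], and margin = 0.5, each entry is
# ceil(ceil(target_size * weight) * 1.005); e.g. the first split of the first dataset gets
# ceil(ceil(1000 * 0.3) * 1.005) = ceil(300 * 1.005) = 302 samples (modulo floating-point
# rounding), i.e. a 0.5% buffer on top of the weighted request.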
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import functools
import logging
import re
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron.core.datasets.utils import Split, log_single_rank, normalize
logger = logging.getLogger(__name__)
@dataclass
class BlendedMegatronDatasetConfig:
"""Configuration object for Megatron Core datasets"""
random_seed: int
"""The seed for all RNG during dataset creation."""
sequence_length: int
"""The sequence length."""
blend: Optional[Tuple[List[str], Optional[List[float]]]] = None
"""The blend, consisting of a list of dataset prefixes and optionally a list of dataset
weights. For example, [["dataset-path1", "dataset-path2"], [0.3, 0.7]]. When the weights are
None, they are inferred from the lengths of the contributing datasets. Not to be used with
'blend_per_split'. Defaults to None.
"""
blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] = None
"""A set of blends, as defined above, one for each split distribution. Not to be used with
'blend'. Defaults to None.
"""
split: Optional[str] = None
"""The split string, a comma separated weighting for the dataset splits when drawing samples
from a single distribution. Not to be used with 'blend_per_split'. Defaults to None.
"""
split_matrix: Optional[List[Tuple[float, float]]] = field(init=False, default=None)
"""The split matrix consisting of non-overlapping book-ends of each split in order. For more
information, refer to 'convert_split_vector_to_split_matrix'. Created automatically from
'split'. Not to be passed in to the constructor.
"""
num_dataset_builder_threads: int = 1
"""The number of threads to use for dataset building."""
path_to_cache: Optional[str] = None
"""Where all re-useable dataset indices are to be cached."""
mmap_bin_files: bool = True
"""Whether to mmap the .bin files or use file pointers."""
mock: bool = field(init=False, default=False)
"""Whether to bypass real data loading and validation in favor of mock data generation.
Created automatically from 'blend' and 'blend_per_split'. Not to be passed in to the
constructor.
"""
tokenizer: Optional[MegatronTokenizer] = None
"""The MegatronTokenizer instance. Required for datasets that do online tokenization."""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
if self.blend_per_split is not None and any(self.blend_per_split):
assert self.blend is None, "blend and blend_per_split are incompatible"
assert self.split is None, "split and blend_per_split are incompatible"
assert len(self.blend_per_split) == len(
Split
), f"blend_per_split must contain {len(Split)} blends"
for split in Split:
if self.blend_per_split[split.value] is None:
log_single_rank(
logger, logging.INFO, f"blend not provided for {split.name} split"
)
else:
assert self.blend_per_split[split.value][1] is None or len(
self.blend_per_split[split.value][0]
) == len(
self.blend_per_split[split.value][1]
), "blend per split prefixes and weights must be equal in number"
else:
if self.blend is not None:
assert self.blend[1] is None or len(self.blend[0]) == len(
self.blend[1]
), "blend prefixes and weights must be equal in number"
assert self.split is not None, "split must be provided when blend is not None"
else:
self.mock = True
log_single_rank(
logger,
logging.INFO,
f"Let mock = True, as both blend and blend_per_split are None",
)
self.split = "1,1,1"
log_single_rank(
logger,
logging.INFO,
f"Let split = {self.split}, an arbitrarily even split, as mock is True",
)
split_vector = parse_and_normalize_split(self.split)
self.split_matrix = convert_split_vector_to_split_matrix(split_vector)
log_single_rank(logger, logging.INFO, f"Let split_matrix = {self.split_matrix}")
def parse_and_normalize_split(split: str) -> List[float]:
"""Parse the dataset split ratios from a string
Args:
split (str): The train valid test split string e.g. "99,1,0"
Returns:
List[float]: The train valid test split ratios e.g. [0.99, 0.01, 0.0]
"""
split = list(map(float, re.findall(r"[.0-9]+", split)))
split = split + [0.0 for _ in range(len(Split) - len(split))]
assert len(split) == len(Split)
assert all(map(lambda _: _ >= 0.0, split))
split = normalize(split)
return split
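# Illustrative example: a two-way string such as "90,10" is padded with a trailing zero to cover
# all members of Split and then normalized, yielding [0.9, 0.1, 0.0]; "99,1,0" yields
# [0.99, 0.01, 0.0] as in the docstring above.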
def convert_split_vector_to_split_matrix(
vector_a: List[float], vector_b: Optional[List[float]] = None
) -> List[Optional[Tuple[float, float]]]:
"""Build the split matrix from one or optionally two contributing split vectors.
Ex. a standard conversion:
[0.99, 0.01, 0.0] -> [(0, 0.99), (0.99, 1.0), None]
Ex. a conversion for Retro when Retro pretraining uses a [0.99, 0.01, 0.0] split and Retro
preprocessing used a [0.98, 0.02, 0.0] split:
[0.99, 0.01, 0.0], [0.98, 0.02, 0.0] -> [(0, 0.98), (0.99, 1.0), None]
Args:
vector_a (List[float]): The primary split vector
vector_b (Optional[List[float]]): An optional secondary split vector which constrains the
primary split vector. Defaults to None.
Returns:
List[Tuple[float, float]]: The split matrix consisting of book-ends of each split in order
"""
if vector_b is None:
vector_b = vector_a
# [.900, .090, .010] -> [0.00, .900, .990, 1.00]
expansion_a = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_a])
expansion_b = functools.reduce(lambda a, b: a + [a[len(a) - 1] + b], [[0], *vector_b])
# [0.00, .900, .990, 1.00] -> [(0.00, .900), (.900, .990), (.990, 1.00)]
bookends_a = list(zip(expansion_a[:-1], expansion_a[1:]))
bookends_b = list(zip(expansion_b[:-1], expansion_b[1:]))
# gather per-split overlap or None
matrix = []
for bookend_a, bookend_b in zip(bookends_a, bookends_b):
if min(bookend_a[1], bookend_b[1]) <= max(bookend_a[0], bookend_b[0]):
overlap = None
else:
overlap = (max(bookend_a[0], bookend_b[0]), min(bookend_a[1], bookend_b[1]))
matrix.append(overlap)
return matrix
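# Worked example of the standard single-vector case: for vector_a = [0.99, 0.01, 0.0] the
# cumulative expansion is [0, 0.99, 1.0, 1.0], the book-ends are
# [(0, 0.99), (0.99, 1.0), (1.0, 1.0)], and the zero-width final interval collapses to None,
# giving [(0, 0.99), (0.99, 1.0), None] as in the docstring above.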
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import logging
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import numpy
import torch
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from megatron.core.datasets.utils import Split
from megatron.core.datasets.utils_s3 import S3Config, is_s3_path
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
_PAD_TOKEN_ID = -1
@dataclass
class GPTDatasetConfig(BlendedMegatronDatasetConfig):
"""Configuration object for Megatron Core GPT datasets"""
reset_position_ids: bool = None
"""Option to reset the position IDs in the dataset at an interval"""
reset_attention_mask: bool = None
"""Option to reset the attention mask from the dataset"""
eod_mask_loss: bool = None
"""Option to enable the EOD mask loss"""
create_attention_mask: bool = True
"""Option to enable the attention masks generation. Can be disabled if attention kernel
generates masks by itself.
"""
drop_last_partial_validation_sequence: bool = True
"""Option to drop the last partial validation sequence"""
add_extra_token_to_sequence: bool = True
"""Option to draw sequences with one extra token to ensure the sample input tokens and sample
output tokens are both of the desired sequence length
"""
s3_cache_path: str = None
"""Path for caching indices for s3 dataloading."""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
assert self.tokenizer is not None
assert self.reset_position_ids is not None
assert self.reset_attention_mask is not None
assert self.eod_mask_loss is not None
class GPTDataset(MegatronDataset):
"""The base GPT dataset
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the GPTDataset
dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When
None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (GPTDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: Optional[str],
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: GPTDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.masks_and_position_ids_are_cacheable = not any(
[
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
]
)
self.masks_and_position_ids_are_cached = False
self.cached_attention_mask = None
self.cached_loss_mask = None
self.cached_position_ids = None
try:
self._pad_token_id = self.config.tokenizer.pad
except Exception:
self._pad_token_id = _PAD_TOKEN_ID
(self.document_index, self.sample_index, self.shuffle_index) = (
self._build_document_sample_shuffle_indices()
)
@staticmethod
def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
"""Abstract method implementation
For GPT, the underlying IndexedDataset should be split by sequence, as opposed to, say,
BERT, which should be split by document
Args:
low_level_dataset (IndexedDataset): The underlying IndexedDataset
Returns:
int: The number of unique elements in the underlying IndexedDataset
"""
return low_level_dataset.sequence_lengths.shape[0]
@staticmethod
def build_low_level_dataset(dataset_path: str, config: GPTDatasetConfig) -> IndexedDataset:
"""Abstract method implementation
Args:
dataset_path (str): The real path prefix to the IndexedDataset .bin and .idx files
config (GPTDatasetConfig): The config
Returns:
IndexedDataset: The underlying IndexedDataset
"""
if is_s3_path(dataset_path):
return IndexedDataset(
dataset_path,
multimodal=False,
mmap=config.mmap_bin_files,
s3_config=S3Config(path_to_idx_cache=config.s3_cache_path),
)
return IndexedDataset(dataset_path, multimodal=False, mmap=config.mmap_bin_files)
def __len__(self) -> int:
"""Abstract method implementation
Returns:
int: The length of the dataset
"""
return self.sample_index.shape[0] - 1
def __getitem__(self, idx: Optional[int]) -> Dict[str, torch.Tensor]:
"""Abstract method implementation
Args:
idx (Optional[int]): The index into the dataset
Returns:
Dict[str, torch.Tensor]: The sample information wrapped in a dictionary
"""
if idx is None:
# Batch padding sequence so the index does not matter
text, _ = self._query_document_sample_shuffle_indices(0)
else:
text, _ = self._query_document_sample_shuffle_indices(idx)
text = torch.from_numpy(text).long()
if self.config.add_extra_token_to_sequence:
tokens = text[:-1].contiguous()
labels = text[1:].contiguous()
else:
tokens = text
labels = torch.roll(text, shifts=-1, dims=0)
labels[-1] = self._pad_token_id
if (
not self.masks_and_position_ids_are_cacheable
or not self.masks_and_position_ids_are_cached
):
attention_mask, loss_mask, position_ids = _get_ltor_masks_and_position_ids(
tokens,
self.config.tokenizer.eod,
self.config.reset_position_ids,
self.config.reset_attention_mask,
self.config.eod_mask_loss,
self.config.create_attention_mask,
)
if self.masks_and_position_ids_are_cacheable:
self.cached_attention_mask = attention_mask
self.cached_loss_mask = loss_mask
self.cached_position_ids = position_ids
self.masks_and_position_ids_are_cached = True
else:
attention_mask = self.cached_attention_mask
loss_mask = self.cached_loss_mask
position_ids = self.cached_position_ids
# For padded sequences, mask the loss
loss_mask[labels == self._pad_token_id] = 0.0
# For padded sequences, ensure the embedding layer can map the token ID
tokens[tokens == self._pad_token_id] = 0
labels[labels == self._pad_token_id] = 0
# Batch padding sequence so we mask the loss
if idx is None:
loss_mask = torch.zeros_like(loss_mask)
if self.config.create_attention_mask:
return {
"tokens": tokens,
"labels": labels,
"attention_mask": attention_mask,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
else:
return {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"position_ids": position_ids,
}
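# Illustrative example with hypothetical token ids: with add_extra_token_to_sequence = True and a
# drawn text of [5, 6, 7, 8, 9] for sequence_length = 4, tokens become [5, 6, 7, 8] and labels
# become [6, 7, 8, 9]; with add_extra_token_to_sequence = False the labels are instead the tokens
# rolled left by one, with the final position set to the pad id and masked out of the loss.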
def _query_document_sample_shuffle_indices(
self, idx: int
) -> Tuple[numpy.ndarray, numpy.ndarray]:
"""Get the text (token ids) and document ids for a given index
Args:
idx (int): The index into the dataset
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The text ids and document ids
"""
# Do the shuffle mapping
idx = self.shuffle_index[idx]
# Get the beginning and end documents and offsets
doc_index_beg, doc_index_beg_offset = self.sample_index[idx]
doc_index_end, doc_index_end_offset = self.sample_index[idx + 1]
document_ids = []
sample_parts = []
# Sample spans a single document
if doc_index_beg == doc_index_end:
# Add the document id
document_ids.append(self.document_index[doc_index_beg])
# Add the entire sample
sample_parts.append(
self.dataset.get(
self.document_index[doc_index_beg],
offset=doc_index_beg_offset,
length=doc_index_end_offset
- doc_index_beg_offset
+ self.config.add_extra_token_to_sequence,
)
)
# Sample spans multiple documents
else:
for i in range(doc_index_beg, doc_index_end + 1):
# Add the document id
document_ids.append(self.document_index[i])
# Add the sample part
offset = 0 if i > doc_index_beg else doc_index_beg_offset
length = (
None
if i < doc_index_end
else doc_index_end_offset + self.config.add_extra_token_to_sequence
)
sample_parts.append(
self.dataset.get(self.document_index[i], offset=offset, length=length)
)
assert len(document_ids) == len(
sample_parts
), f"len(document_ids) ({len(document_ids)}) != len(sample_parts) ({len(sample_parts)})"
length = sum(map(len, sample_parts))
# Pad the sample if necessary
if length < (self.config.sequence_length + self.config.add_extra_token_to_sequence):
sample_parts.append(
[self._pad_token_id]
* (self.config.sequence_length + self.config.add_extra_token_to_sequence - length)
)
return (
numpy.concatenate(sample_parts, dtype=numpy.int64),
numpy.array(document_ids, dtype=numpy.int64),
)
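# Illustrative example with hypothetical indices: if shuffle_index maps idx to 7,
# sample_index[7] = (3, 10), and sample_index[8] = (5, 2), the sample takes the remainder of
# document_index[3] starting at offset 10, all of document_index[4], and the first 2 (+1 extra
# token, when enabled) tokens of document_index[5]; if the total falls short of
# sequence_length (+1) it is padded with the pad token.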
def _build_document_sample_shuffle_indices(
self,
) -> Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]:
"""Build the document index, the sample index, and the shuffle index
The document index:
-- 1-D
-- An ordered array of document ids
The sample index:
-- 2-D
-- The document indices and offsets which mark the start of every sample
The shuffle index:
-- 1-D
-- A random permutation of index range of the sample index
Returns:
Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
index, and the shuffle index
"""
path_to_cache = self.config.path_to_cache
if path_to_cache is None and not self.config.mock:
path_to_cache = os.path.join(
self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
)
if path_to_cache:
base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}"
get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}")
path_to_description = get_path_to("description.txt")
path_to_document_index = get_path_to("document_index.npy")
path_to_sample_index = get_path_to("sample_index.npy")
path_to_shuffle_index = get_path_to("shuffle_index.npy")
cache_hit = all(
map(
os.path.isfile,
[
path_to_description,
path_to_document_index,
path_to_sample_index,
path_to_shuffle_index,
],
)
)
else:
cache_hit = False
if not path_to_cache or (
not cache_hit
and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0)
):
log_single_rank(
logger,
logging.INFO,
f"Build and save the {type(self).__name__} {self.index_split.name} indices",
)
self.built_anew_on_cache_miss = True
t_beg = time.time()
sequence_length = self.config.sequence_length
num_tokens_per_epoch = self._get_num_tokens_per_epoch()
num_epochs = self._get_num_epochs(num_tokens_per_epoch)
if num_epochs == 1:
separate_final_epoch = False
else:
# Get the number of samples for the last epoch
num_samples_sans_final_epoch = (
(num_epochs - 1) * num_tokens_per_epoch
- self.config.add_extra_token_to_sequence
) // sequence_length
num_samples_from_final_epoch = self.num_samples - num_samples_sans_final_epoch
num_samples_per_epoch = (
num_tokens_per_epoch - self.config.add_extra_token_to_sequence
) // sequence_length
# num_samples_from_final_epoch should be non-negative
assert num_samples_from_final_epoch >= 0
# num_samples_from_final_epoch should not exceed max value
assert num_samples_from_final_epoch <= num_samples_per_epoch + 1
# Separate the final epoch if it falls below the threshold
threshold = 0.80
separate_final_epoch = num_samples_from_final_epoch < int(
threshold * num_samples_per_epoch
)
log_single_rank(
logger,
logging.DEBUG,
f"> num_samples_from_final_epoch: {num_samples_from_final_epoch}",
)
log_single_rank(logger, logging.DEBUG, f"> threshold: {threshold}")
log_single_rank(
logger, logging.DEBUG, f"> num_samples_per_epoch: {num_samples_per_epoch}"
)
log_single_rank(
logger, logging.DEBUG, f"> separate_final_epoch: {separate_final_epoch}"
)
numpy_random_state = numpy.random.RandomState(self.config.random_seed)
# Build the document index
document_index = _build_document_index(
self.indices, num_epochs, numpy_random_state, separate_final_epoch
)
drop_last_partial_sequence = True
if self.index_split == Split.valid:
drop_last_partial_sequence = self.config.drop_last_partial_validation_sequence
# Build the sample index
from megatron.core.datasets import helpers
assert document_index.dtype == numpy.int32
assert self.dataset.sequence_lengths.dtype == numpy.int32
if len(document_index) * 2 > len(self.dataset.sequence_lengths):
# If "access density" of sequence_lengths is high, force load the mmap-ed array
# into memory by making a copy.
#
# System performance benefits come from two aspects:
# 1. We sequentially pre-load the whole file, most of which we expect to read
# 2. The GIL is held while the C++ helper runs, so making that call faster
#    improves overall parallelism
sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy()
else:
sequence_lengths_for_cpp = self.dataset.sequence_lengths
sample_index = helpers.build_sample_idx(
sequence_lengths_for_cpp,
document_index,
sequence_length,
num_epochs,
num_tokens_per_epoch,
drop_last_partial_sequence,
self.config.add_extra_token_to_sequence,
)
# Build the shuffle index
if separate_final_epoch:
shuffle_index = _build_shuffle_index(
num_samples_sans_final_epoch, sample_index.shape[0] - 1, numpy_random_state
)
else:
shuffle_index = _build_shuffle_index(
sample_index.shape[0] - 1, sample_index.shape[0] - 1, numpy_random_state
)
if path_to_cache:
os.makedirs(path_to_cache, exist_ok=True)
# Write the description
with open(path_to_description, "wt") as writer:
writer.write(self.unique_description)
numpy.save(path_to_document_index, document_index, allow_pickle=True)
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
numpy.save(path_to_shuffle_index, shuffle_index, allow_pickle=True)
else:
log_single_rank(
logger,
logging.WARNING,
f"Unable to save {type(self).__name__} indexes because path_to_cache is None",
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
return document_index, sample_index, shuffle_index
log_single_rank(
logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
)
log_single_rank(
logger,
logging.INFO,
f"\tLoad the document index from {os.path.basename(path_to_document_index)}",
)
t_beg = time.time()
document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger,
logging.INFO,
f"\tLoad the shuffle index from {os.path.basename(path_to_shuffle_index)}",
)
t_beg = time.time()
shuffle_index = numpy.load(path_to_shuffle_index, allow_pickle=True, mmap_mode='r')
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0] - 1}"
)
return document_index, sample_index, shuffle_index
def _get_num_tokens_per_epoch(self) -> int:
"""Calculate the number of tokens in a single epoch
Returns:
int: The number of tokens in a single epoch
"""
return int(numpy.sum(self.dataset.sequence_lengths[self.indices]))
def _get_num_epochs(self, num_tokens_per_epoch: int) -> int:
"""Calculate the number of epochs
Args:
num_tokens_per_epoch (int): The number of tokens in a single epoch
Returns:
int: The number of epochs
"""
num_epochs = 1
num_tokens = num_tokens_per_epoch
if self.num_samples is None:
return num_epochs
else:
num_tokens_requested = (
self.num_samples * self.config.sequence_length
) + self.config.add_extra_token_to_sequence
while num_tokens < num_tokens_requested:
num_epochs += 1
num_tokens += num_tokens_per_epoch
return num_epochs
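# Worked example with illustrative numbers: for num_samples = 10, sequence_length = 512,
# add_extra_token_to_sequence = 1, and num_tokens_per_epoch = 1000, the request is
# 10 * 512 + 1 = 5121 tokens, so the loop increments until 6 * 1000 >= 5121 and returns 6 epochs.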
def _build_document_index(
documents: numpy.ndarray,
num_epochs: int,
numpy_random_state: numpy.random.RandomState,
separate_final_epoch: bool,
) -> numpy.ndarray:
"""Build an array with length = num epochs * num documents
Args:
documents (numpy.ndarray): the subset of exposed document indices
num_epochs (int): The number of epochs
numpy_random_state (numpy.random.RandomState): The NumPy random state
separate_final_epoch (bool): Whether to exclude the last epoch from the global shuffle
Returns:
numpy.ndarray: The document index
"""
if not separate_final_epoch or num_epochs == 1:
document_index = numpy.mgrid[0:num_epochs, 0 : len(documents)][1]
document_index[:] = documents
document_index = document_index.reshape(-1)
document_index = document_index.astype(numpy.int32)
numpy_random_state.shuffle(document_index)
return document_index
doc_idx_first = _build_document_index(documents, num_epochs - 1, numpy_random_state, False)
doc_idx_last = _build_document_index(documents, 1, numpy_random_state, False)
return numpy.concatenate((doc_idx_first, doc_idx_last))
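# Illustrative example with hypothetical document ids: for documents = [5, 6, 7] and
# num_epochs = 2, the non-separated path tiles the ids into [5, 6, 7, 5, 6, 7] and shuffles the
# whole array; with separate_final_epoch = True the first num_epochs - 1 copies and the final
# copy are shuffled independently and concatenated, so the final epoch stays contiguous at the
# end.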
def _build_shuffle_index(
num_samples: int, total_size: int, numpy_random_state: numpy.random.RandomState
) -> numpy.ndarray:
"""Build the range [0, size) and shuffle
Args:
num_samples (int): The size of the first shuffle range [0, num_samples)
total_size (int): The size of the entire index. If larger than 'num_samples', it defines
the second shuffle range [num_samples, total_size)
numpy_random_state (numpy.random.RandomState): The NumPy random state
Returns:
numpy.ndarray: The shuffle index
"""
dtype_ = numpy.uint32
if total_size >= (numpy.iinfo(numpy.uint32).max - 1):
dtype_ = numpy.int64
shuffle_idx_first = numpy.arange(start=0, stop=num_samples, step=1, dtype=dtype_)
numpy_random_state.shuffle(shuffle_idx_first)
if num_samples == total_size:
return shuffle_idx_first
shuffle_idx_last = numpy.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_)
numpy_random_state.shuffle(shuffle_idx_last)
return numpy.concatenate((shuffle_idx_first, shuffle_idx_last))
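# Illustrative example with hypothetical sizes: _build_shuffle_index(4, 6, rng) returns a random
# permutation of [0, 1, 2, 3] followed by a random permutation of [4, 5], so samples from the
# separated final epoch (the tail range) never mix into the head range; when num_samples equals
# total_size only the single shuffled range is returned.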
def _get_ltor_masks_and_position_ids(
data: torch.Tensor,
eod_token: int,
reset_position_ids: bool,
reset_attention_mask: bool,
eod_mask_loss: bool,
create_attention_mask: bool,
):
"""Build masks and position id for left to right model.
Args:
data (torch.Tensor): The data tensor that holds the tokens from the dataset
eod_token (int): ID of the token that is considered the EOD
reset_position_ids (bool): Switch to reset the document position ID's
reset_attention_mask (bool): Switch to reset the attention mask
eod_mask_loss (bool): Switch to enable the EOD mask loss
create_attention_mask (bool): Switch to enable the attention masks generation. Can be
disabled if attention kernel generates masks by itself.
Returns:
torch.Tensor: Attention mask needed to be used for Attention
torch.Tensor: The mask used for loss value during training
torch.Tensor: The position ID's of the token
"""
seq_length = data.numel()
if create_attention_mask:
attention_mask = torch.tril(
torch.ones((seq_length, seq_length), device=data.device)
).unsqueeze(0)
else:
attention_mask = None
# Loss mask.
loss_mask = torch.ones(seq_length, dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Find indices where EOD token is.
eod_index = position_ids[data == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.numel()):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask and attention_mask is not None:
attention_mask[0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[(i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
if attention_mask is not None:
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
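# Illustrative example with hypothetical token ids and eod_token = 0: for data = [5, 6, 0, 7, 8],
# reset_position_ids = True yields position ids [0, 1, 2, 0, 1]; reset_attention_mask = True
# additionally blocks positions 3 and 4 from attending to positions 0..2; and
# eod_mask_loss = True zeroes the loss mask at index 2.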
class MockGPTLowLevelDataset:
"""The mock GPT low level dataset
This class is meant to generate tokenized data in the classic "Megatron-LM" GPT style. Notably,
we add the end of document token to each element indexed in __getitem__
Args:
tokenizer (MegatronTokenizer): The tokenizer the special token information of which we use
to augment the mock data.
"""
seed: int = 0
"""The hard-coded random seed to use to set the NumPy RNG"""
size: int = 100000
"""The hard-coded number of samples to generate"""
max_sequence_length: int = 4096
"""The hard-coded max sequence length to generate"""
def __init__(self, tokenizer: MegatronTokenizer) -> None:
self.tokenizer = tokenizer
rng = numpy.random.default_rng(seed=self.seed)
self.sequence_lengths = rng.integers(
low=1, high=self.max_sequence_length, size=self.size, dtype=numpy.int32
)
def __len__(self) -> int:
return self.size
def __getitem__(self, idx: int) -> numpy.number:
length = self.sequence_lengths[idx]
sample = numpy.int64(
numpy.concatenate([numpy.arange(length - 1) + 1, [self.tokenizer.eod]])
)
return sample
def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
"""This function is n abstraction over __getitem__ with support for slicing
Args:
idx (int): The index into the dataset
offset (int): The integer token offset in the sequence
length (Optional[int]): The number of tokens to grab from the sequence
Returns:
numpy.ndarray: The sequence tokens at the index
"""
if length is None:
length = self.sequence_lengths[idx] - offset
return self[idx][offset : offset + length]
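# Illustrative example: if sequence_lengths[idx] is 6, __getitem__ returns
# [1, 2, 3, 4, 5, eod], and get(idx, offset=2, length=3) slices out [3, 4, 5]; with length=None
# the slice runs from the offset to the end of the sequence.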
class MockGPTDataset(GPTDataset):
"""The mock GPT dataset
Args:
dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build
the MockGPTDataset
dataset_path (Optional[str]): This argument is of no consequence for the MockGPTDataset
indices (numpy.ndarray): The set of the dataset indices to expose
num_samples (int): The number of samples to draw from the dataset
index_split (Split): The indices Split
config (GPTDatasetConfig): The config
"""
def __init__(
self,
dataset: MockGPTLowLevelDataset,
dataset_path: Optional[str],
indices: numpy.ndarray,
num_samples: int,
index_split: Split,
config: GPTDatasetConfig,
) -> None:
assert config.mock
super().__init__(dataset, dataset_path, indices, num_samples, index_split, config)
@staticmethod
def numel_low_level_dataset(low_level_dataset: MockGPTLowLevelDataset) -> int:
"""Abstract method implementation
Args:
low_level_dataset (MockGPTLowLevelDataset): The underlying MockGPTLowLevelDataset
Returns:
int: The number of unique elements in the underlying MockGPTLowLevelDataset
"""
return len(low_level_dataset)
@staticmethod
def build_low_level_dataset(
dataset_path: Optional[str], config: GPTDatasetConfig
) -> MockGPTLowLevelDataset:
"""Abstract method implementation
Args:
dataset_path (Optional[str]): This argument is of no consequence for the
MockGPTLowLevelDataset
config (GPTDatasetConfig): The config
Returns:
MockGPTLowLevelDataset: The underlying MockGPTLowLevelDataset
"""
return MockGPTLowLevelDataset(config.tokenizer)
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
/* Helper methods for fast index mapping builds */
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <set>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <random>
namespace py = pybind11;
using namespace std;
const int32_t LONG_SENTENCE_LEN = 512;
void build_exhaustive_blending_indices(py::array_t<int16_t> &dataset_index, py::array_t<int64_t> &dataset_sample_index, const py::array_t<int64_t> &sizes, const int32_t num_datasets) {
/*
Build blending indices by sampling exactly as many samples from dataset[i]
as is requested by sizes[i] for all i in the range [0, num_datasets).
*/
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto sizes_ptr = sizes.unchecked<1>();
int64_t total_size = 0;
int64_t dataset_sample_counts[num_datasets];
std::set<int32_t> dataset_unspent_indices;
for (int32_t i = 0; i < num_datasets; ++i) {
total_size += sizes_ptr[i];
dataset_sample_counts[i] = 0;
dataset_unspent_indices.insert(i);
}
// still need fractional weights to sample in proportion to sizes
double weights[num_datasets];
for (int32_t i = 0; i < num_datasets; ++i) {
weights[i] = sizes_ptr[i] / static_cast<double>(total_size);
}
int64_t index_sample = 0;
while (dataset_unspent_indices.size() > 0) {
double index_sample_double = std::max(static_cast<double>(index_sample), 1.0);
int64_t error_argmax;
double error_max = std::numeric_limits<double>::lowest();
for (int32_t index_dataset : dataset_unspent_indices) {
double error = weights[index_dataset] * index_sample_double - static_cast<double>(dataset_sample_counts[index_dataset]);
if (error > error_max) {
error_argmax = index_dataset;
error_max = error;
}
}
// Populate the indices.
dataset_index_ptr[index_sample] = static_cast<int16_t>(error_argmax);
dataset_sample_index_ptr[index_sample] = dataset_sample_counts[error_argmax];
// Update the total samples.
dataset_sample_counts[error_argmax] += 1;
if (sizes_ptr[error_argmax] - static_cast<double>(dataset_sample_counts[error_argmax]) == 0) {
dataset_unspent_indices.erase(error_argmax);
}
index_sample += 1;
}
}
void build_blending_indices(py::array_t<int16_t> &dataset_index,
py::array_t<int64_t> &dataset_sample_index,
const py::array_t<double> &weights,
const int32_t num_datasets,
const int64_t size, const bool verbose)
{
/* Given multiple datasets and a weighting array, build samples
such that they follow those weights. */
if (verbose)
{
std::cout << "> building indices for blended datasets ..." << std::endl;
}
// Get the pointer access without the checks.
auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
auto weights_ptr = weights.unchecked<1>();
// Initialize buffer for number of samples used for each dataset.
int64_t current_samples[num_datasets];
for (int64_t i = 0; i < num_datasets; ++i)
{
current_samples[i] = 0;
}
// For each sample:
for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx)
{
// Determine where the max error in sampling is happening.
auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
int64_t max_error_index = 0;
double max_error = weights_ptr[0] * sample_idx_double -
static_cast<double>(current_samples[0]);
for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx)
{
double error = weights_ptr[dataset_idx] * sample_idx_double -
static_cast<double>(current_samples[dataset_idx]);
if (error > max_error)
{
max_error = error;
max_error_index = dataset_idx;
}
}
// Populate the indices.
dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index];
// Update the total samples.
current_samples[max_error_index] += 1;
}
// print info
if (verbose)
{
std::cout << " > sample ratios:" << std::endl;
for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx)
{
auto ratio = static_cast<double>(current_samples[dataset_idx]) /
static_cast<double>(size);
std::cout << " dataset " << dataset_idx << ", input: " << weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
}
}
}
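// Worked example (illustrative): for weights = [0.3, 0.7] and size = 10, the greedy
// error-minimizing loop above emits dataset_index = [1, 0, 1, 1, 0, 1, 1, 0, 1, 1] and
// dataset_sample_index = [0, 0, 1, 2, 1, 3, 4, 2, 5, 6], achieving the requested 0.3 / 0.7
// ratio exactly (3 and 7 samples respectively); with other weights the achieved ratio can
// differ slightly from the input, as reported by the verbose printout.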
template <typename T>
py::array_t<T> build_sample_idx(
const py::array_t<int32_t> &sizes_,
const py::array_t<int32_t> &document_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch,
const bool drop_last_partial_sequence = true,
const int add_extra_token_to_sequence = 1
){
/*
Sample index (sample_idx) is used for GPT-2-like datasets for which the documents are flattened
and the samples are built from this 1-D flattened array. It is a 2D array with sizes
[number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is
the starting offset in that document.
*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto document_idx = document_idx_.unchecked<1>();
// Build the sample idx as a contiguous 1-D array of type T.
int64_t num_samples = 0;
if (drop_last_partial_sequence == true) {
num_samples = (num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length;
}
else {
num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
}
T *sample_idx = new T[2 * (num_samples + 1)];
// Index into sample_idx.
int64_t sample_idx_index = 0;
// Index into document_idx.
T document_idx_index = 0;
// Beginning offset for each document.
T doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_idx_index] = document_idx_index;
sample_idx[2 * sample_idx_index + 1] = doc_offset;
++sample_idx_index;
while (sample_idx_index <= num_samples)
{
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + add_extra_token_to_sequence;
while (remaining_seq_length != 0)
{
// Get the document length.
auto document_index = document_idx[document_idx_index];
auto document_length = sizes[document_index] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= document_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0)
{
doc_offset += (remaining_seq_length + document_length - add_extra_token_to_sequence);
remaining_seq_length = 0;
}
else
{
// Otherwise, start from the beginning of the next document.
if (document_idx_index == (document_idx_.shape(0) - 1))
{
// If we have reached the end of the documents, break.
assert(sample_idx_index == num_samples);
doc_offset = sizes[document_idx[document_idx_index]] - add_extra_token_to_sequence;
break;
}
++document_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_idx_index] = document_idx_index;
sample_idx[2 * sample_idx_index + 1] = doc_offset;
++sample_idx_index;
}
// Method to deallocate memory.
py::capsule free_when_done(
sample_idx,
[](void *mem_){
T *mem = reinterpret_cast<T*>(mem_);
delete[] mem;
}
);
// Return the numpy array.
const auto byte_size = sizeof(T);
return py::array_t<T>(
std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done // numpy array references
);
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937 &rand32_gen)
{
/* Training sample length. */
if (short_seq_ratio == 0)
{
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0)
{
return 2 + random_number % (max_length - 1);
}
return max_length;
}
template <typename DocIdx>
py::array build_mapping_impl(const py::array_t<int64_t> &docs_,
const py::array_t<int32_t> &sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose,
const int32_t min_num_sent)
{
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0)
{
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose)
{
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl
<< std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 << endl
<< std::flush;
cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
<< std::flush;
cout << " total number of sentences: " << num_sentences << endl
<< std::flush;
cout << " number of epochs: " << num_epochs << endl
<< std::flush;
cout << " maximum number of samples: " << max_num_samples << endl
<< std::flush;
cout << " maximum sequence length: " << max_seq_length << endl
<< std::flush;
cout << " short sequence probability: " << short_seq_prob << endl
<< std::flush;
cout << " short sequence ration (1/prob): " << short_seq_ratio << endl
<< std::flush;
cout << " seed: " << seed << endl
<< std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx *maps = NULL;
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration = 0; iteration < 2; ++iteration)
{
// Set the seed so both iterations produce the same results.
std::mt19937 rand32_gen(seed);
// Set the flag on second iteration.
second = (iteration == 1);
// Counters:
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// Current map index.
uint64_t map_index = 0;
// For each epoch:
for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
{
if (map_index >= max_num_samples)
{
if (verbose && (!second))
{
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl
<< std::flush;
}
break;
}
// For each document:
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
{
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
                // At the beginning of the document the previous index is the
                // start index.
auto prev_start_index = sent_index_first;
                // Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second))
{
if (num_remain_sent == 0)
{
++empty_docs;
}
if (num_remain_sent == 1)
{
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1)
{
for (auto sent_index = sent_index_first;
sent_index < sent_index_last; ++sent_index)
{
if (sizes[sent_index] > LONG_SENTENCE_LEN)
{
if ((epoch == 0) && (!second))
{
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
                // If we have at least min_num_sent sentences and no overly long sentence.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
{
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
auto target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
// Loop through sentences.
for (auto sent_index = sent_index_first;
sent_index < sent_index_last; ++sent_index)
{
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
                        // If we have reached the target length,
                        // and more than one sentence remains in the document,
                        // and we have collected at least min_num_sent sentences,
                        // or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent >= min_num_sent)) ||
(num_remain_sent == 0))
{
// Check for overflow.
if ((3 * map_index + 2) >
std::numeric_limits<int64_t>::max())
{
cout << "number of samples exceeded maximum "
<< "allowed by type int64: "
<< std::numeric_limits<int64_t>::max()
<< endl;
throw std::overflow_error("Number of samples");
}
// Populate the map.
if (second)
{
const auto map_index_0 = 3 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(target_seq_len);
}
// Update indices / counters.
++map_index;
prev_start_index = sent_index + 1;
target_seq_len = get_target_sample_len(short_seq_ratio,
max_seq_length,
rand32_gen);
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
                } // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second)
{
if (verbose)
{
cout << " number of empty documents: " << empty_docs << endl
<< std::flush;
cout << " number of documents with one sentence: " << one_sent_docs << endl
<< std::flush;
cout << " number of documents with long sentences: " << long_sent_docs << endl
<< std::flush;
cout << " will create mapping for " << map_index << " samples" << endl
<< std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3 * map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i = (num_samples - 1); i > 0; --i)
{
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_)
{
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem; });
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3 * byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_mapping(const py::array_t<int64_t> &docs_,
const py::array_t<int> &sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose,
const int32_t min_num_sent)
{
if (sizes_.size() > std::numeric_limits<uint32_t>::max())
{
if (verbose)
{
cout << " using uint64 for data mapping..." << endl
<< std::flush;
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
else
{
if (verbose)
{
cout << " using uint32 for data mapping..." << endl
<< std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose,
min_num_sent);
}
}
template <typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t> &docs_,
const py::array_t<int32_t> &sizes_,
const py::array_t<int32_t> &titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks)
{
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose)
{
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl
<< std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 << endl
<< std::flush;
cout << " sentences range: [" << sent_start_index << ", " << sent_end_index << ")" << endl
<< std::flush;
cout << " total number of sentences: " << num_sentences << endl
<< std::flush;
cout << " number of epochs: " << num_epochs << endl
<< std::flush;
cout << " maximum number of samples: " << max_num_samples << endl
<< std::flush;
cout << " maximum sequence length: " << max_seq_length << endl
<< std::flush;
cout << " seed: " << seed << endl
<< std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx *maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks)
{
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration = 0; iteration < 2; ++iteration)
{
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch = 0; epoch < num_epochs; ++epoch)
{
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples)
{
if (verbose && (!second))
{
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl
<< std::flush;
}
break;
}
// For each document:
for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc)
{
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
                // At the beginning of the document the previous index is the
                // start index.
auto prev_start_index = sent_index_first;
                // Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second))
{
if (num_remain_sent == 0)
{
++empty_docs;
}
if (num_remain_sent == 1)
{
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent)
{
for (auto sent_index = sent_index_first;
sent_index < sent_index_last; ++sent_index)
{
if (sizes[sent_index] > LONG_SENTENCE_LEN)
{
if ((epoch == 0) && (!second))
{
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence))
{
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index = sent_index_first;
sent_index < sent_index_last; ++sent_index)
{
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
                        // If we have reached the target length,
                        // and an acceptable number of sentences remain,
                        // and we have at least the minimum number of sentences,
                        // or if we have reached the end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent)) ||
(num_remain_sent == 0))
{
// Populate the map.
if (second)
{
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
                } // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second)
{
if (verbose)
{
cout << " number of empty documents: " << empty_docs << endl
<< std::flush;
cout << " number of documents with one sentence: " << one_sent_docs << endl
<< std::flush;
cout << " number of documents with long sentences: " << long_sent_docs << endl
<< std::flush;
cout << " will create mapping for " << map_index << " samples" << endl
<< std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4 * map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i = (num_samples - 1); i > 0; --i)
{
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_)
{
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem; });
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4 * byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
const py::array_t<int> &sizes_,
const py::array_t<int> &titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks)
{
if (sizes_.size() > std::numeric_limits<uint32_t>::max())
{
if (verbose)
{
cout << " using uint64 for data mapping..." << endl
<< std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
else
{
if (verbose)
{
cout << " using uint32 for data mapping..." << endl
<< std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers_cpp, m)
{
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx_int32", &build_sample_idx<int32_t>);
m.def("build_sample_idx_int64", &build_sample_idx<int64_t>);
m.def("build_blending_indices", &build_blending_indices);
m.def("build_exhaustive_blending_indices", &build_exhaustive_blending_indices);
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import numpy
# Implicit imports for backwards compatibility
from megatron.core.datasets.helpers_cpp import *
# Explicit imports for readability
from megatron.core.datasets.helpers_cpp import build_sample_idx_int32, build_sample_idx_int64
def build_sample_idx(
sizes: numpy.ndarray,
document_indices: numpy.ndarray,
sequence_length: int,
num_epochs: int,
tokens_per_epoch: int,
drop_last_partial_sequence: bool = True,
add_extra_token_to_sequence: bool = True,
):
"""Build the 2-D sample index using the properly typed templated C++ function from helpers.cpp
Args:
sizes (numpy.ndarray): The 1-D array of document lengths
document_indices (numpy.ndarray): The 1-D array of document indices
sequence_length (int): The sequence length
num_epochs (int): The number of epochs
tokens_per_epoch (int): The number of tokens per epoch
drop_last_partial_sequence (bool): Whether to omit the last partial sequence in the sample
index should it exist. Defaults to True.
add_extra_token_to_sequence (bool): Whether to build samples with sequence length
`sequence_length + 1`. Defaults to True.
Returns:
numpy.ndarray: The 2-D sample index
"""
sample_idx_max = max(document_indices.shape[0], sizes.max())
if sample_idx_max <= numpy.iinfo(numpy.int32).max:
sample_idx = build_sample_idx_int32(
sizes,
document_indices,
sequence_length,
num_epochs,
tokens_per_epoch,
drop_last_partial_sequence,
1 if add_extra_token_to_sequence else 0,
)
assert sample_idx.min() >= 0 and sample_idx.max() <= sample_idx_max
else:
sample_idx = build_sample_idx_int64(
sizes,
document_indices,
sequence_length,
num_epochs,
tokens_per_epoch,
drop_last_partial_sequence,
1 if add_extra_token_to_sequence else 0,
)
return sample_idx
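# A minimal, illustrative smoke test of the wrapper above (toy values chosen for
# this sketch); guarded so it only runs when the module is executed directly.
if __name__ == "__main__":
    # Two documents of 5 and 7 tokens, exposed for one epoch, packed into sequences
    # of length 4 (plus one extra token, per `add_extra_token_to_sequence`).
    _sizes = numpy.array([5, 7], dtype=numpy.int32)
    _document_indices = numpy.array([0, 1], dtype=numpy.int32)
    _sample_idx = build_sample_idx(
        _sizes, _document_indices, sequence_length=4, num_epochs=1, tokens_per_epoch=12
    )
    # The result has one more row than there are samples; each row holds a
    # (position in `_document_indices`, token offset) pair.
    print(_sample_idx.shape)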
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Essentially re-written in entirety
import logging
import os
import shutil
import struct
import time
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from itertools import accumulate
from types import TracebackType
from typing import List, Optional, Tuple, Type, Union
try:
import boto3
except ModuleNotFoundError:
pass
import numpy
import torch
from megatron.core.datasets.utils_s3 import (
S3Config,
is_s3_path,
maybe_download_file,
object_exists,
parse_s3_path,
)
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
_INDEX_HEADER = b"MMIDIDX\x00\x00"
class DType(Enum):
"""The NumPy data type Enum for writing/reading the IndexedDataset indices"""
uint8 = 1
int8 = 2
int16 = 3
int32 = 4
int64 = 5
float64 = 6
float32 = 7
uint16 = 8
@classmethod
def code_from_dtype(cls, value: Type[numpy.number]) -> int:
"""Get the code from the dtype
Args:
value (Type[numpy.number]): The dtype
Returns:
int: The code
"""
return cls[value.__name__].value
@classmethod
def dtype_from_code(cls, value: int) -> Type[numpy.number]:
"""Get the dtype from the code
Args:
value (int): The code
Returns:
Type[numpy.number]: The dtype
"""
return getattr(numpy, cls(value).name)
@staticmethod
def size(key: Union[int, Type[numpy.number]]) -> int:
"""Get the size of the dtype/code in bytes
Args:
key (Union[int, Type[numpy.number]]): The dtype or code
Raises:
ValueError: If the key is neither dtype nor integer code
Returns:
            int: The size of the dtype/code in bytes
"""
if isinstance(key, int):
return DType.dtype_from_code(key)().itemsize
elif numpy.number in key.__mro__:
return key().itemsize
else:
raise ValueError
@staticmethod
def optimal_dtype(cardinality: Optional[int]) -> Type[numpy.number]:
"""Get the dtype to use for an index of a certain cardinality
Args:
cardinality (Optional[int]): The number of elements to be indexed
Returns:
Type[numpy.number]: The dtype to use for the index
"""
if cardinality is not None and cardinality < 65500:
return numpy.uint16
else:
return numpy.int32
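# Doctest-style sketch of the DType helpers above (for illustration; not executed):
#   >>> DType.code_from_dtype(numpy.int32)
#   4
#   >>> DType.dtype_from_code(4) is numpy.int32
#   True
#   >>> DType.size(numpy.int64)
#   8
#   >>> DType.optimal_dtype(70_000) is numpy.int32  # cardinality >= 65500
#   True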
class _IndexWriter(object):
"""Object class to write the index (.idx) file
Args:
idx_path (str): The path to the index file
dtype (Type[numpy.number]): The dtype of the index file
"""
def __init__(self, idx_path: str, dtype: Type[numpy.number]) -> None:
self.idx_path = idx_path
self.dtype = dtype
def __enter__(self) -> "_IndexWriter":
"""Enter the context introduced by the 'with' keyword
Returns:
_IndexWriter: The instance
"""
self.idx_writer = open(self.idx_path, "wb")
        # the index header, fixed as a vestigial practice
        self.idx_writer.write(_INDEX_HEADER)
        # the index version, fixed at 1 as a vestigial practice
        self.idx_writer.write(struct.pack("<Q", 1))
# the numeric code for the dtype
self.idx_writer.write(struct.pack("<B", DType.code_from_dtype(self.dtype)))
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> Optional[bool]:
"""Exit the context introduced by the 'with' keyword
Args:
exc_type (Optional[Type[BaseException]]): Exception type
exc_val (Optional[BaseException]): Exception value
exc_tb (Optional[TracebackType]): Exception traceback object
Returns:
Optional[bool]: Whether to silence the exception
"""
self.idx_writer.close()
def write(
self,
sequence_lengths: List[int],
sequence_modes: Optional[List[int]],
document_indices: List[int],
) -> None:
"""Write the index (.idx) file
Args:
sequence_lengths (List[int]): The length of each sequence
            sequence_modes (Optional[List[int]]): The mode of each sequence
            document_indices (List[int]): The sequence indices demarcating the end of each document
"""
sequence_pointers = self._sequence_pointers(sequence_lengths)
# the number of sequences in the dataset
sequence_count = len(sequence_lengths)
self.idx_writer.write(struct.pack("<Q", sequence_count))
# the number of documents in the dataset
document_count = len(document_indices)
self.idx_writer.write(struct.pack("<Q", document_count))
# the number of tokens per sequence
sequence_lengths = numpy.array(sequence_lengths, dtype=numpy.int32)
self.idx_writer.write(sequence_lengths.tobytes(order="C"))
del sequence_lengths
# the byte offsets for all sequences
sequence_pointers = numpy.array(sequence_pointers, dtype=numpy.int64)
self.idx_writer.write(sequence_pointers.tobytes(order="C"))
del sequence_pointers
# the sequence indices marking the end of each document
document_indices = numpy.array(document_indices, dtype=numpy.int64)
self.idx_writer.write(document_indices.tobytes(order="C"))
# the mode per sequence
if sequence_modes is not None:
sequence_modes = numpy.array(sequence_modes, dtype=numpy.int8)
self.idx_writer.write(sequence_modes.tobytes(order='C'))
del sequence_modes
def _sequence_pointers(self, sequence_lengths: List[int]) -> List[int]:
"""Build the sequence pointers per the sequence lengths and dtype size
Args:
sequence_lengths (List[int]): The length of each sequence
Returns:
List[int]: The pointer to the beginning of each sequence
"""
itemsize = DType.size(self.dtype)
curr_ptr = 0
list_ptr = []
for length in sequence_lengths:
list_ptr.append(curr_ptr)
curr_ptr += length * itemsize
return list_ptr
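# Worked example of the writer above (illustrative): for sequence_lengths = [3, 4, 2]
# written with dtype = numpy.int32 (4 bytes per token), _sequence_pointers yields
# [0, 3 * 4, (3 + 4) * 4] = [0, 12, 28], i.e. the byte offset of each sequence
# within the data (.bin) file.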
class _IndexReader(object):
"""Object class to read the index (.idx) file
Args:
idx_path (str): The path to the index file
multimodal (bool): Whether the dataset is multimodal
"""
def __init__(self, idx_path: str, multimodal: bool) -> None:
log_single_rank(logger, logging.INFO, f"Load the {type(self).__name__} from {idx_path}")
with open(idx_path, "rb") as stream:
header = stream.read(9)
assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}"
version = struct.unpack("<Q", stream.read(8))[0]
assert version == 1, f"bad version, cannot read: {idx_path}"
code = struct.unpack("<B", stream.read(1))[0]
self.dtype = DType.dtype_from_code(code)
self.dtype_size = DType.size(self.dtype)
self.sequence_count = struct.unpack("<Q", stream.read(8))[0]
self.document_count = struct.unpack("<Q", stream.read(8))[0]
offset = stream.tell()
self.bin_buffer_mmap = numpy.memmap(idx_path, mode="r", order="C")
self.bin_buffer = memoryview(self.bin_buffer_mmap)
log_single_rank(logger, logging.INFO, f"\tExtract the sequence lengths")
t_beg = time.time()
self.sequence_lengths = numpy.frombuffer(
self.bin_buffer, dtype=numpy.int32, count=self.sequence_count, offset=offset
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(logger, logging.INFO, f"\tExtract the sequence pointers")
t_beg = time.time()
self.sequence_pointers = numpy.frombuffer(
self.bin_buffer,
dtype=numpy.int64,
count=self.sequence_count,
offset=offset + self.sequence_lengths.nbytes,
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(logger, logging.INFO, f"\tExtract the document indices")
t_beg = time.time()
self.document_indices = numpy.frombuffer(
self.bin_buffer,
dtype=numpy.int64,
count=self.document_count,
offset=offset + self.sequence_lengths.nbytes + self.sequence_pointers.nbytes,
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
self.sequence_modes = None
if multimodal:
log_single_rank(logger, logging.INFO, f"\tExtract the sequence modes")
t_beg = time.time()
self.sequence_modes = numpy.frombuffer(
self.bin_buffer,
dtype=numpy.int8,
count=self.sequence_count,
offset=offset
+ self.sequence_lengths.nbytes
+ self.sequence_pointers.nbytes
+ self.document_indices.nbytes,
)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
assert self.sequence_lengths.shape[0] == len(self)
assert self.sequence_lengths.shape[0] == self.sequence_count
assert self.sequence_lengths.shape[0] == self.document_indices[-1]
log_single_rank(logger, logging.INFO, f"> total number of sequences: {len(self)}")
log_single_rank(
logger,
logging.INFO,
f"> total number of documents: {self.document_indices.shape[0] - 1}",
)
def __del__(self) -> None:
"""Clean up the object"""
if hasattr(self, "bin_buffer_mmap"):
self.bin_buffer_mmap._mmap.close()
del self.bin_buffer_mmap
def __len__(self) -> int:
"""Return the length of the dataset
Returns:
int: The length of the dataset
"""
return self.sequence_count
@lru_cache(maxsize=8)
def __getitem__(self, idx: int) -> Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]:
"""Return the pointer, length, and mode at the index
Args:
idx (int): The index into the dataset
Returns:
Tuple[numpy.int32, numpy.int64, Optional[numpy.int8]]: The pointer, length and mode at the index
"""
return (
self.sequence_pointers[idx],
self.sequence_lengths[idx],
self.sequence_modes[idx] if self.sequence_modes is not None else None,
)
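# Minimal usage sketch of _IndexReader above (hypothetical path); the reader mirrors
# the _IndexWriter layout:
#   >>> index = _IndexReader("/path/to/dataset.idx", multimodal=False)
#   >>> len(index)                        # the number of sequences
#   >>> pointer, length, mode = index[0]  # mode is None for non-multimodal data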
class _BinReader(ABC):
"""Abstract class to read the data (.bin) file"""
@abstractmethod
def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
"""Read bytes into a numpy array.
Args:
dtype (Type[numpy.number]): Data-type of the returned array.
count (int): Number of items to read.
offset (int): Start reading from this offset (in bytes).
Returns:
numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`.
"""
pass
class _MMapBinReader(_BinReader):
"""A _BinReader that memory maps the data (.bin) file
Args:
        bin_path (str): The path to the data (.bin) file.
"""
def __init__(self, bin_path: str) -> None:
self._bin_buffer_mmap = numpy.memmap(bin_path, mode="r", order="C")
self._bin_buffer = memoryview(self._bin_buffer_mmap)
def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
"""Read bytes into a numpy array.
Args:
dtype (Type[numpy.number]): Data-type of the returned array.
count (int): Number of items to read.
offset (int): Start reading from this offset (in bytes).
Returns:
numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`.
"""
return numpy.frombuffer(self._bin_buffer, dtype=dtype, count=count, offset=offset)
def __del__(self) -> None:
"""Clean up the object."""
if self._bin_buffer_mmap is not None:
self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap
class _FileBinReader(_BinReader):
"""A _BinReader that reads from the data (.bin) file using a file pointer
Args:
        bin_path (str): The path to the data (.bin) file.
"""
def __init__(self, bin_path: str) -> None:
self._bin_path = bin_path
def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
"""Read bytes into a numpy array.
Args:
dtype (Type[numpy.number]): Data-type of the returned array.
count (int): Number of items to read.
offset (int): Start reading from this offset (in bytes).
Returns:
numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`.
"""
sequence = numpy.empty(count, dtype=dtype)
with open(self._bin_path, mode='rb', buffering=0) as bin_buffer_file:
bin_buffer_file.seek(offset)
bin_buffer_file.readinto(sequence)
return sequence
class _S3BinReader(_BinReader):
"""A _BinReader that reads from the data (.bin) file from S3
Args:
        bin_path (str): The path to the data (.bin) file.
        bin_chunk_nbytes (int): Maintain an in-memory cache to speed up calls to the `read` method. On a cache miss, download this number of bytes to refresh the cache.
"""
def __init__(self, bin_path: str, bin_chunk_nbytes: int) -> None:
assert bin_chunk_nbytes > 0
self._client = boto3.client("s3")
self._s3_bucket, self._s3_key = parse_s3_path(bin_path)
self._cache = None
self._cache_bytes_start = None
self._cache_bytes_end = None
self._cache_nbytes = bin_chunk_nbytes
def _extract_from_cache(self, offset: int, size: int) -> bytes:
"""Extract `size` bytes starting at `offset` bytes into the cache"""
start = offset - self._cache_bytes_start
assert start >= 0
end = start + size
assert end <= len(self._cache)
return self._cache[start:end]
def read(self, dtype: Type[numpy.number], count: int, offset: int) -> numpy.ndarray:
"""Read bytes into a numpy array.
Let `size` be the `count` * `DType.size(dtype)`. If the requested span of bytes [`offset`,
`offset` + `size`) is covered by the in-memory cache maintained by this class, then this
function extracts the requested span from that cache and returns it. Otherwise, this
function first refreshes the cache and then extracts the requested span from the refreshed
cache and returns it.
The cache is refreshed based on `offset` and `size`. In particular, we divide all the bytes
in an S3 object into blocks, where each block contains `bin_chunk_nbytes` bytes. We assign
each block an index starting from 0. We take the block with index (`offset` //
`bin_chunk_nbytes`) to refresh the cache. If this new block still does not cover the
requested span, we extend it just enough to include `offset` + `size`.
Args:
dtype (Type[numpy.number]): Data-type of the returned array.
count (int): Number of items to read.
offset (int): Start reading from this offset (in bytes).
Returns:
numpy.ndarray: An array with `count` items and data-type `dtype` constructed from reading bytes from the data file starting at `offset`.
"""
size = count * DType.size(dtype)
if (
self._cache is not None
and offset >= self._cache_bytes_start
and offset + size <= self._cache_bytes_end
):
return numpy.frombuffer(self._extract_from_cache(offset, size), dtype=dtype)
bytes_start = (offset // self._cache_nbytes) * self._cache_nbytes
assert bytes_start >= 0
assert offset >= bytes_start
bytes_end = max(bytes_start + self._cache_nbytes, offset + size)
assert bytes_end >= 1
self._cache = self._client.get_object(
Bucket=self._s3_bucket,
Key=self._s3_key,
# Subtract 1, because the end of Range is inclusive.
Range=f'bytes={bytes_start}-{bytes_end-1}',
)['Body'].read()
self._cache_bytes_start = bytes_start
self._cache_bytes_end = bytes_end
return numpy.frombuffer(self._extract_from_cache(offset, size), dtype=dtype)
def __del__(self) -> None:
"""Clean up the object"""
self._client.close()
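# Worked example of the _S3BinReader cache arithmetic above (assumed values for
# illustration): with bin_chunk_nbytes = 1024, a read of 16 int32 items
# (size = 16 * 4 = 64 bytes) at offset = 2500 computes
#   bytes_start = (2500 // 1024) * 1024 = 2048
#   bytes_end   = max(2048 + 1024, 2500 + 64) = 3072
# so a single GetObject request with Range=bytes=2048-3071 refreshes the cache, and
# this read plus any nearby reads inside that span are served from memory.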
class IndexedDataset(torch.utils.data.Dataset):
"""The low-level interface dataset class
Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
multimodal (bool): Whether the dataset is multimodal. Defaults to False.
mmap (bool): Whether to mmap the .bin files. Defaults to True.
s3_config (Optional[S3Config]): Supplied only for data stored on S3. IndexedDataset downloads the index (.idx) file to `s3_config.path_to_idx_cache` and streams data from the data (.bin) file in `s3_config.bin_chunk_nbytes` blocks. Note that `mmap` must be disabled for S3 data loading. Defaults to None.
"""
def __init__(
self,
path_prefix: str,
multimodal: bool = False,
mmap: bool = True,
s3_config: Optional[S3Config] = None,
) -> None:
super().__init__()
self.path_prefix = None
self.multimodal = None
self.mmap = None
self.s3_config = None
self.index = None
self.bin_reader = None
if is_s3_path(path_prefix) and s3_config is not None:
idx_path = get_idx_path(path_prefix)
cache_idx_path = os.path.join(s3_config.path_to_idx_cache, os.path.basename(idx_path))
maybe_download_file(idx_path, cache_idx_path)
self.initialize(path_prefix, multimodal, mmap, s3_config)
def initialize(
self, path_prefix: str, multimodal: bool, mmap: bool, s3_config: Optional[S3Config]
) -> None:
"""Initialize the dataset
This method is called by IndexedDataset.__init__ during object creation and by
IndexedDataset.__setstate__ during un-pickling
Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
multimodal (bool): Whether the dataset is multimodal
mmap (bool): Whether to mmap the .bin file
s3_config (Optional[S3Config]): See IndexedDataset docstring for details.
"""
idx_path = get_idx_path(path_prefix)
bin_path = get_bin_path(path_prefix)
if s3_config is None:
assert os.path.exists(idx_path) and os.path.exists(
bin_path
), f"One or both of the .idx and .bin files cannot be found at the path prefix {path_prefix}"
self.path_prefix = path_prefix
self.multimodal = multimodal
self.mmap = mmap
self.s3_config = s3_config
if mmap:
assert not s3_config
self.bin_reader = _MMapBinReader(bin_path)
elif s3_config:
assert not mmap
self.bin_reader = _S3BinReader(bin_path, s3_config.bin_chunk_nbytes)
idx_path = os.path.join(
s3_config.path_to_idx_cache, os.path.basename(get_idx_path(path_prefix))
)
else:
self.bin_reader = _FileBinReader(bin_path)
self.index = _IndexReader(idx_path, self.multimodal)
def __getstate__(self) -> Tuple[str, bool, bool, Optional[S3Config]]:
"""Get the state during pickling
Returns:
Tuple[str, bool, bool, Optional[S3Config]]: The state tuple
"""
return self.path_prefix, self.multimodal, self.mmap, self.s3_config
def __setstate__(self, state: Tuple[str, bool, bool, Optional[S3Config]]) -> None:
"""Set the state during un-pickling
Args:
state (Tuple[str, bool, bool, Optional[S3Config]]): The state tuple
"""
path_prefix, multimodal, mmap, s3_config = state
self.initialize(path_prefix, multimodal, mmap, s3_config)
def __del__(self) -> None:
"""Clean up the object"""
del self.bin_reader
del self.index
def __len__(self) -> int:
"""Return the length of the dataset i.e. the number of sequences in the index
Returns:
int: The length of the dataset
"""
return len(self.index)
def __getitem__(
self, idx: Union[int, numpy.integer, slice]
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
"""Return from the dataset
Args:
idx (Union[int, numpy.integer, slice]): The index or index slice into the dataset
Raises:
ValueError: When the index slice is non-contiguous
TypeError: When the index is of an unexpected type
Returns:
Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index or index slice
"""
if isinstance(idx, (int, numpy.integer)):
sequence_pointer, sequence_length, sequence_mode = self.index[idx]
sequence = self.bin_reader.read(
dtype=self.index.dtype, count=sequence_length, offset=sequence_pointer
)
return (sequence, sequence_mode) if sequence_mode is not None else sequence
elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
if step != 1:
raise ValueError("Slices into indexed_dataset must be contiguous")
sequence_lengths = self.index.sequence_lengths[idx]
sequence_modes = self.index.sequence_modes[idx] if self.multimodal else None
sequence_offsets = list(accumulate(sequence_lengths))
sequences = numpy.split(
self.bin_reader.read(
dtype=self.index.dtype,
count=sum(sequence_lengths),
offset=self.index.sequence_pointers[start],
),
sequence_offsets[:-1],
)
return (sequences, sequence_modes) if sequence_modes is not None else sequences
else:
raise TypeError("Unexpected type received for idx: {}".format(type(idx)))
def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
"""Retrieve a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
Args:
idx (Union[int, numpy.integer]): The index into the dataset
offset (int): The integer token offset in the sequence
        length (Optional[int]): The number of tokens to grab from the sequence; when None, read to the end of the sequence
Returns:
Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]: The sequence tokens and modes at the index
"""
sequence_pointer, sequence_length, sequence_mode = self.index[idx]
if length is None:
length = sequence_length - offset
sequence_pointer += offset * DType.size(self.index.dtype)
sequence = self.bin_reader.read(
dtype=self.index.dtype, count=length, offset=sequence_pointer
)
return (sequence, sequence_mode) if sequence_mode is not None else sequence
@property
def sequence_lengths(self) -> numpy.ndarray:
"""Get the sequence lengths
Returns:
numpy.ndarray: The sequence lengths
"""
return self.index.sequence_lengths
@property
def document_indices(self) -> numpy.ndarray:
"""Get the document indices
Returns:
numpy.ndarray: The document indices
"""
return self.index.document_indices
def get_document_indices(self) -> numpy.ndarray:
"""Get the document indices
This method is slated for deprecation.
Returns:
numpy.ndarray: The document indices
"""
return self.index.document_indices
def set_document_indices(self, document_indices: numpy.ndarray) -> None:
"""Set the document indices
This method is slated for deprecation.
Args:
document_indices (numpy.ndarray): The document indices
"""
self.index.document_indices = document_indices
@property
def sequence_modes(self) -> numpy.ndarray:
"""Get the sequence modes
Returns:
numpy.ndarray: The sequence modes
"""
return self.index.sequence_modes
@staticmethod
def exists(path_prefix: str) -> bool:
"""Return whether the IndexedDataset exists on disk at the prefix
Args:
path_prefix (str): The prefix to the index (.idx) and data (.bin) files
Returns:
bool: Whether the IndexedDataset exists on disk at the prefix
"""
if is_s3_path(path_prefix):
s3_client = boto3.client("s3")
return object_exists(s3_client, get_idx_path(path_prefix)) and object_exists(
s3_client, get_bin_path(path_prefix)
)
return os.path.exists(get_idx_path(path_prefix)) and os.path.exists(
get_bin_path(path_prefix)
)
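# Minimal usage sketch of IndexedDataset above (hypothetical path prefix, assuming
# dataset.idx and dataset.bin exist on local disk):
#   >>> ds = IndexedDataset("/path/to/dataset")
#   >>> len(ds)                                # the number of sequences
#   >>> tokens = ds[0]                         # the full first sequence as a numpy array
#   >>> chunk = ds.get(0, offset=2, length=3)  # 3 tokens starting at token 2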
class IndexedDatasetBuilder(object):
"""Builder class for the IndexedDataset class
Args:
bin_path (str): The path to the data (.bin) file
dtype (Type[numpy.number], optional): The dtype of the index file. Defaults to numpy.int32.
multimodal (bool, optional): Whether the dataset is multimodal. Defaults to False.
"""
def __init__(
self, bin_path: str, dtype: Type[numpy.number] = numpy.int32, multimodal: bool = False
) -> None:
self.data_file = open(bin_path, "wb")
self.dtype = dtype
self.multimodal = multimodal
self.sequence_lengths = []
self.document_indices = [0]
self.sequence_modes = [] if self.multimodal else None
def add_item(self, tensor: torch.Tensor, mode: int = 0) -> None:
"""Add a single item to the dataset
Args:
tensor (torch.Tensor): The item to add to the data file
mode (int, optional): The mode for the item. Defaults to 0.
"""
np_array = numpy.array(tensor.numpy(), dtype=self.dtype)
self.data_file.write(np_array.tobytes(order="C"))
self.sequence_lengths.append(np_array.size)
if self.multimodal:
self.sequence_modes.append(mode)
def add_document(
self, tensor: torch.Tensor, lengths: List[int], modes: Optional[List[int]] = None
) -> None:
"""Add an entire document to the dataset
Args:
tensor (torch.Tensor): The document to add
lengths (List[int]): The lengths of each item in the document
modes (Optional[List[int]], optional): The modes for each item in the document. Defaults to None.
"""
np_array = numpy.array(tensor, dtype=self.dtype)
self.data_file.write(np_array.tobytes(order="C"))
self.sequence_lengths.extend(lengths)
self.document_indices.append(len(self.sequence_lengths))
if self.multimodal:
            self.sequence_modes.extend(modes if modes is not None else [0] * len(lengths))
def end_document(self) -> None:
"""Finalize the document, for use with IndexedDatasetBuilder.add_item"""
self.document_indices.append(len(self.sequence_lengths))
def add_index(self, path_prefix: str) -> None:
"""Add an entire IndexedDataset to the dataset
Args:
path_prefix (str): The index (.idx) and data (.bin) prefix
"""
# Concatenate index
index = _IndexReader(get_idx_path(path_prefix), multimodal=self.multimodal)
assert index.dtype == self.dtype
offset = len(self.sequence_lengths)
self.sequence_lengths.extend(index.sequence_lengths)
self.document_indices.extend((offset + index.document_indices)[1:])
if self.multimodal:
self.sequence_modes.extend(index.sequence_modes)
# Concatenate data
with open(get_bin_path(path_prefix), "rb") as f:
shutil.copyfileobj(f, self.data_file)
def finalize(self, idx_path: str) -> None:
"""Clean up and write the index (.idx) file
Args:
idx_path (str): The path to the index file
"""
self.data_file.close()
with _IndexWriter(idx_path, self.dtype) as writer:
writer.write(self.sequence_lengths, self.sequence_modes, self.document_indices)
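# Minimal write-then-read round trip sketch for IndexedDatasetBuilder above
# (hypothetical prefix "/tmp/demo"):
#   >>> builder = IndexedDatasetBuilder(get_bin_path("/tmp/demo"), dtype=numpy.int32)
#   >>> builder.add_item(torch.tensor([1, 2, 3]))
#   >>> builder.end_document()
#   >>> builder.finalize(get_idx_path("/tmp/demo"))
#   >>> IndexedDataset("/tmp/demo")[0]
#   array([1, 2, 3], dtype=int32)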
def get_idx_path(path_prefix: str) -> str:
"""Get the path to the index file from the prefix
Args:
path_prefix (str): The prefix
Returns:
str: The path to the index file
"""
return path_prefix + ".idx"
def get_bin_path(path_prefix: str) -> str:
"""Get the path to the data file from the prefix
Args:
path_prefix (str): The prefix
Returns:
str: The path to the data file
"""
return path_prefix + ".bin"
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import os
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy
import torch
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset
from megatron.core.datasets.utils import Split
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
@dataclass
class MaskedWordPieceDatasetConfig(BlendedMegatronDatasetConfig):
"""Configuration object for Megatron Core Masked WordPiece datasets"""
masking_probability: float = None
"""The probability we mask a candidate N-gram"""
short_sequence_probability: float = None
"""The probability we return a sequence shorter than the target sequence length"""
masking_max_ngram: int = None
"""The maximum length N-gram to consider masking or permuting"""
masking_do_full_word: bool = None
"""Whether we mask the whole word or its component parts"""
masking_do_permutation: bool = None
"""Whether we shuffle a subset of candidate N-grams in addition"""
masking_use_longer_ngrams: bool = None
"""Whether to favor longer N-grams over shorter N-grams"""
masking_use_geometric_distribution: bool = None
"""Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT
https://arxiv.org/abs/1907.10529 (Section 3.1)
"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
assert self.tokenizer is not None
assert self.masking_probability is not None
assert self.short_sequence_probability is not None
assert self.masking_max_ngram is not None
assert self.masking_do_full_word is not None
assert self.masking_do_permutation is not None
assert self.masking_use_longer_ngrams is not None
assert self.masking_use_geometric_distribution is not None
assert self.masking_probability > 0 and self.masking_probability < 1.0
assert self.short_sequence_probability >= 0 and self.short_sequence_probability <= 1.0
assert self.masking_max_ngram > 0
assert not (self.masking_use_geometric_distribution and self.masking_do_permutation)
if self.masking_use_geometric_distribution and self.masking_use_longer_ngrams:
log_single_rank(
logger,
logging.WARNING,
"The use of a geometric distribution overrides the default distribution",
)
class MaskedWordPieceDataset(MegatronDataset):
"""The semi-abstract base class for masked WordPiece datasets
This implementation makes the rigid assumption that all inheritor datasets are built upon the
IndexedDataset class. This assumption may be pushed down to the inheritors in future if
necessary.
NB: WordPiece tokenization prepends a double hash "##" to all tokens/pieces in a word, save the
first token/piece.
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the
MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
        indexed_indices (numpy.ndarray): The set of document indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset.
When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
@staticmethod
def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
return low_level_dataset.document_indices.shape[0] - 1
@staticmethod
def build_low_level_dataset(
dataset_path: str, config: MaskedWordPieceDatasetConfig
) -> IndexedDataset:
return IndexedDataset(dataset_path)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(MaskedWordPieceDataset, MaskedWordPieceDataset)._key_config_attributes() + [
"masking_probability",
"short_sequence_probability",
"masking_max_ngram",
"masking_do_full_word",
"masking_do_permutation",
"masking_use_longer_ngrams",
"masking_use_geometric_distribution",
]
def __len__(self) -> int:
return self.sample_index.shape[0]
def _build_sample_index(
self, sequence_length: int, min_sentences_per_sample: int
) -> numpy.ndarray:
path_to_cache = self.config.path_to_cache
if path_to_cache is None:
path_to_cache = os.path.join(
self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
)
get_path_to = lambda suffix: os.path.join(
path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
)
path_to_description = get_path_to("description.txt")
path_to_sample_index = get_path_to("sample_index.npy")
cache_hit = all(map(os.path.isfile, [path_to_description, path_to_sample_index]))
if self.num_samples is not None:
num_epochs = numpy.iinfo(numpy.int32).max - 1
else:
num_epochs = 1
if not cache_hit and torch.distributed.get_rank() == 0:
log_single_rank(
logger,
logging.INFO,
f"Build and save the {type(self).__name__} {self.index_split.name} indices",
)
self.built_anew_on_cache_miss = True
os.makedirs(path_to_cache, exist_ok=True)
# Write the description
with open(path_to_description, "wt") as writer:
writer.write(self.unique_description)
# Build the sample index
log_single_rank(
logger,
logging.INFO,
f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
from megatron.core.datasets import helpers
# Add +1 for access to document upper bound
indices = numpy.append(self.indices, self.indices[-1] + 1)
sample_index = helpers.build_mapping(
self.dataset.document_indices[indices],
self.dataset.sequence_lengths,
num_epochs,
self.num_samples,
sequence_length,
self.config.short_sequence_probability,
self.config.random_seed,
False,
min_sentences_per_sample,
)
numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
log_single_rank(
logger, logging.INFO, f"> total number of samples: {sample_index.shape[0]}"
)
log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")
return sample_index
log_single_rank(
logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
)
log_single_rank(
logger,
logging.INFO,
f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
)
t_beg = time.time()
sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode="r")
t_end = time.time()
log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")
return sample_index
def _create_masked_lm_predictions(
self,
token_ids: List[int],
target_sequence_length: int,
numpy_random_state: numpy.random.RandomState,
) -> Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]:
"""Creates the predictions for the masked LM objective
Args:
token_ids (List[int]): The token ids
target_sequence_length (int): The target sequence length
numpy_random_state (numpy.random.RandomState): The NumPy random state
Returns:
Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]:
1. masked_token_ids -> The masked sequence
2. masked_positions -> The indices for the masked token ids
3. masked_labels -> The original token ids for the masked token ids
            4. boundaries -> The sentence and word boundaries for the sequence
            5. masked_spans -> The masked positions and labels with N-gram info intact
"""
# Build the token sentence and word boundaries and the masking candidates
# e.g. [cls, id, ##id, ##id, id, ##id, sep, id, ##id, sep]
# -> boundaries: [1, 1, 0, 0, 1, 0, 1, 1, 0, 1]
# -> candidates with whole word masking: [[1, 2, 3], [4, 5], [7, 8]]
# -> candidates sans whole word masking: [[1], [2], [3], [4], [5], [7], [8]]
boundaries = []
candidates = []
for i, token_id in enumerate(token_ids):
if token_id == self.config.tokenizer.cls or token_id == self.config.tokenizer.sep:
boundaries.append(1)
else:
if not self.config.tokenizer.inv_vocab[token_id].startswith("##"):
boundaries.append(1)
candidates.append([i])
else:
boundaries.append(0)
if self.config.masking_do_full_word and len(candidates) > 0:
candidates[-1].append(i)
else:
candidates.append([i])
n_maskings = min(
self.config.masking_probability * target_sequence_length,
max(1, int(round(len(token_ids) * self.config.masking_probability))),
)
ngram_nvals = numpy.arange(self.config.masking_max_ngram, dtype=numpy.int64) + 1
# By default, the N-gram probabilities are inversely proportional to N
# e.g. N = 3
# -> P = array([0.54545455, 0.27272727, 0.18181818])
nprobs = 1.0 / ngram_nvals
nprobs = nprobs / nprobs.sum(keepdims=True)
if self.config.masking_use_longer_ngrams:
nprobs = nprobs[::-1]
# Create a nested list of depth 3
# layer 1: the candidate dimension
# layer 2: the N-gram dimension
# layer 3: the token dimension
candidate_ngrams = [
[candidates[idx : idx + n] for n in ngram_nvals] for idx in range(len(candidates))
]
numpy_random_state.shuffle(candidate_ngrams)
masked_token_ids = list(token_ids)
masked_positions_and_labels = []
masked_spans = []
masked_indices = set()
for candidate_idx in range(len(candidate_ngrams)):
n_ngrams = len(candidate_ngrams[candidate_idx])
# Stop when we hit our desired number of maskings
if len(masked_positions_and_labels) >= n_maskings:
break
# Do nothing for candidates with no ngrams
if not candidate_ngrams[candidate_idx]:
continue
# Choose the initial value of N
if self.config.masking_use_geometric_distribution:
# Sample N from a geometric distribution with p = 0.2 and clip
# i.e. SpanBERT
# -> https://arxiv.org/abs/1907.10529 (Section 3.1)
p = 0.2
n = min(numpy_random_state.geometric(p), self.config.masking_max_ngram)
else:
p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True)
n = numpy_random_state.choice(ngram_nvals[:n_ngrams], p=p)
while True:
ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], [])
n = n - 1
# Success: masking this N-gram puts us below the desired number of maskings
if n_maskings >= len(masked_positions_and_labels) + len(ngram_indices):
skip_candidate = False
break
# Failure: no N-grams remain for this candidate
if n == 0:
skip_candidate = True
break
# Do nothing for candidates whose 1-gram is too long
if skip_candidate:
continue
# Do nothing for candidate indices which have already been masked
if any(map(lambda idx: idx in masked_indices, ngram_indices)):
continue
# Mask the tokens and record their original positions and values
for index in ngram_indices:
masked_indices.add(index)
mask = self._get_token_mask(numpy_random_state)
if mask is None:
masked_token_ids[index] = token_ids[index]
else:
masked_token_ids[index] = mask
masked_positions_and_labels.append((index, token_ids[index]))
masked_spans.append((ngram_indices, [token_ids[index] for index in ngram_indices]))
assert len(masked_positions_and_labels) <= n_maskings
numpy_random_state.shuffle(candidate_ngrams)
if self.config.masking_do_permutation:
n_swappings = n_maskings
permuted_indices = set()
for candidate_idx in range(len(candidate_ngrams)):
n_ngrams = len(candidate_ngrams[candidate_idx])
if len(permuted_indices) >= n_swappings:
break
# Do nothing for candidates with no ngrams
if not candidate_ngrams[candidate_idx]:
continue
p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True)
                n = numpy_random_state.choice(ngram_nvals[:n_ngrams], p=p)
while True:
ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], [])
n = n - 1
# Success: swapping this N-gram puts us below the desired number of swappings
if n_swappings >= len(permuted_indices) + len(ngram_indices):
skip_candidate = False
break
# Failure: no N-grams remain for this candidate
if n == 0:
skip_candidate = True
break
# Do nothing for candidates whose 1-gram is too long
if skip_candidate:
continue
# Do nothing for candidate indices which have already been masked or permuted
if any(
map(lambda idx: idx in masked_indices or idx in permuted_indices, ngram_indices)
):
continue
for index in ngram_indices:
permuted_indices.add(index)
assert len(permuted_indices) <= n_swappings
permuted_indices = sorted(permuted_indices)
permuted_indices_copy = list(permuted_indices)
numpy_random_state.shuffle(permuted_indices_copy)
masked_token_ids_copy = list(masked_token_ids)
for idx, idx_copy in zip(permuted_indices, permuted_indices_copy):
masked_token_ids[idx] = masked_token_ids_copy[idx_copy]
masked_positions_and_labels.append((idx, masked_token_ids_copy[idx]))
masked_positions_and_labels = sorted(masked_positions_and_labels, key=lambda x: x[0])
masked_positions = []
masked_labels = []
for position, label in masked_positions_and_labels:
masked_positions.append(position)
masked_labels.append(label)
masked_spans = sorted(masked_spans, key=lambda x: x[0][0])
return masked_token_ids, masked_positions, masked_labels, boundaries, masked_spans
@abstractmethod
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
pass
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import hashlib
import json
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy
import torch
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.utils import Split
LowLevelDataset = Union[IndexedDataset, Iterable]
class MegatronDataset(ABC, torch.utils.data.Dataset):
"""The highest level wrapper class from which all dataset classes should inherit
Args:
dataset (LowLevelDataset): The dataset around which to build the MegatronDataset
dataset_path (Optional[str]): The real path on disk to the dataset, for bookkeeping
        indices (numpy.ndarray): The set of document indices to expose
num_samples (Optional[int]): The minimum number of samples to build from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indices Split
config (BlendedMegatronDatasetConfig): The config
"""
def __init__(
self,
dataset: LowLevelDataset,
dataset_path: Optional[str],
indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: BlendedMegatronDatasetConfig,
) -> None:
self.dataset = dataset
self.dataset_path = dataset_path
self.indices = indices
self.num_samples = num_samples
self.index_split = index_split
self.config = config
self.unique_identifiers = OrderedDict()
self.unique_identifiers["class"] = type(self).__name__
self.unique_identifiers["dataset_path"] = self.dataset_path
self.unique_identifiers["num_samples"] = self.num_samples
self.unique_identifiers["index_split"] = self.index_split.name
for attr in self._key_config_attributes():
self.unique_identifiers[attr] = getattr(self.config, attr)
self.unique_description = json.dumps(
self.unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
)
self.unique_description_hash = hashlib.md5(
self.unique_description.encode("utf-8")
).hexdigest()
self.built_anew_on_cache_miss = False
@staticmethod
def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int:
"""Return the number of elements in the underlying low level dataset for the purpose of
segregating the train/valid/test split indices
It may be that the low level dataset can be split any number of ways, depending on the mid
level dataset it supports, which is why we define the "number of elements" function
separately from the __len__ function here in the mid level dataset class
Args:
low_level_dataset (LowLevelDataset): The underlying low level dataset
Returns:
int: The number of elements in the underlying low level dataset
"""
raise NotImplementedError
@staticmethod
def build_low_level_dataset(
dataset_path: str, config: BlendedMegatronDatasetConfig
) -> LowLevelDataset:
"""Build the low level dataset via a function to be called from within
BlendedMegatronDatasetBuilder.build_generic_dataset
It may be that the low level dataset spans any subset of train/valid/test splits, which is
why we define a static "build" function separately from the constructor in the mid level
dataset class
Args:
dataset_path (str): The real path on disk to the dataset
config (BlendedMegatronDatasetConfig): The dataset config
Returns:
LowLevelDataset: The low level dataset
"""
raise NotImplementedError
@staticmethod
def _key_config_attributes() -> List[str]:
"""Return all config attributes which contribute to uniquely identifying the dataset.
These attributes will be used to build a uniquely identifying string and MD5 hash which
will be used to cache/load dataset resources from run to run.
Returns:
List[str]: The key config attributes
"""
return ["random_seed", "sequence_length", "split", "split_matrix", "tokenizer"]
@abstractmethod
def __len__(self) -> int:
"""Return the length of the dataset
Returns:
int: See abstract implementation
"""
pass
@abstractmethod
def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, numpy.ndarray]]:
"""Return from the dataset
Args:
idx (int): The index into the dataset
Returns:
Dict[str, Union[torch.Tensor, numpy.ndarray]]: See abstract implementation
"""
pass
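# Minimal subclass sketch of MegatronDataset above (illustrative only; names are
# hypothetical): the two static hooks describe the low level dataset, while
# __len__/__getitem__ define the samples the mid level dataset exposes.
#   class ToyDataset(MegatronDataset):
#       @staticmethod
#       def numel_low_level_dataset(low_level_dataset: LowLevelDataset) -> int:
#           return len(low_level_dataset)
#       @staticmethod
#       def build_low_level_dataset(dataset_path, config):
#           return list(range(100))  # any iterable qualifies as a LowLevelDataset
#       def __len__(self) -> int:
#           return len(self.indices)
#       def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
#           return {"text": numpy.array([self.dataset[self.indices[idx]]])}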
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import json
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any
import numpy
class MegatronTokenizer(ABC):
"""Abstract class for tokenizer
    Absent a config or class-specific tracking of which objects are uniquely identifying, we must
    include all keyword arguments as unique identifiers
Args:
tokenizer_paths (Tuple[str]): All tokenizer source paths or prefixes
tokenizer_options (Dict[str, Any]): All tokenizer options
"""
def __init__(self, *tokenizer_paths: str, **tokenizer_options: Any):
self.unique_identifiers = OrderedDict()
self.unique_identifiers["class"] = type(self).__name__
self.unique_identifiers["tokenizer_path"] = list(tokenizer_paths)
for option in tokenizer_options:
self.unique_identifiers[option] = str(tokenizer_options[option])
self.unique_description = json.dumps(self.unique_identifiers, indent=4)
super().__init__()
@abstractmethod
def tokenize(self, text: str) -> numpy.ndarray:
"""Convert text to embedding ids
Args:
text (str): The text to convert
Returns:
numpy.ndarray: The converted embedding ids
"""
pass
def detokenize(self, ids: numpy.ndarray) -> str:
"""Convert embedding ids to text
Args:
ids (numpy.ndarray): The ids to convert
Returns:
str: The converted text
Raises:
NotImplementedError: Non-abstract, optional method
"""
raise NotImplementedError("{} has no method 'detokenize'".format(type(self).__name__))
def offsets(self, ids: list[int], text: str) -> list[int]:
"""Convert embedding ids to text offsets
Args:
ids (list[int]): The ids to convert
text (str): The text to convert
Returns:
list[int]: The converted offsets
Raises:
NotImplementedError: Non-abstract, optional method
"""
raise NotImplementedError("{} has no method 'offsets'".format(type(self).__name__))
@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token"""
pass
@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token"""
pass
@property
@abstractmethod
def vocab_size(self):
"""The vocabulary size"""
pass
@property
def cls(self):
"""The CLS token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'cls'".format(type(self).__name__))
@property
def sep(self):
"""The SEP token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'sep'".format(type(self).__name__))
@property
def pad(self):
"""The PAD token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'pad'".format(type(self).__name__))
@property
def eod(self):
"""The EOD token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'eod'".format(type(self).__name__))
@property
def bos(self):
"""The BOS token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'bos'".format(type(self).__name__))
@property
def eos(self):
"""The EOS token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'eos'".format(type(self).__name__))
@property
def mask(self):
"""The MASK token id
Raises:
NotImplementedError: Non-abstract, optional attribute
"""
raise NotImplementedError("{} has no attribute 'mask'".format(type(self).__name__))
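# ------------------------------------------------------------------------------------
# Illustrative sketch (not part of Megatron Core): a toy whitespace tokenizer that
# satisfies the abstract surface above. The one-token-per-line vocab file format is an
# assumption made only for this example.
# ------------------------------------------------------------------------------------
import numpy
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
class WhitespaceTokenizer(MegatronTokenizer):
    def __init__(self, vocab_path: str, lowercase: bool = False):
        super().__init__(vocab_path, lowercase=lowercase)
        with open(vocab_path) as f:
            tokens = [line.strip() for line in f if line.strip()]
        self._vocab = {token: i for i, token in enumerate(tokens)}
        self._inv_vocab = {i: token for token, i in self._vocab.items()}
        self._lowercase = lowercase
    def tokenize(self, text: str) -> numpy.ndarray:
        if self._lowercase:
            text = text.lower()
        return numpy.array([self._vocab[word] for word in text.split()], dtype=numpy.int64)
    def detokenize(self, ids: numpy.ndarray) -> str:
        return " ".join(self._inv_vocab[int(i)] for i in ids)
    @property
    def vocab(self):
        return self._vocab
    @property
    def inv_vocab(self):
        return self._inv_vocab
    @property
    def vocab_size(self):
        return len(self._vocab)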
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable, Dict
import torch
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
@dataclass
class MultimodalDatasetConfig(GPTDatasetConfig):
"""Configuration object for Megatron Core Multimodal datasets.
Note: This is unused at the moment and may be missing features. Follow-up changes will use this.
"""
image_h: int = None
"""Image height."""
image_w: int = None
"""Image width."""
# Function to preprocess the data sample to a format expected by a specific model. By default, do nothing.
preprocess_func: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = lambda x: x
"""Optional function to preprocess data samples for a specific model."""
def __post_init__(self) -> None:
super().__post_init__()
assert self.image_h is not None
assert self.image_w is not None
class MockMultimodalDataset(MockGPTDataset):
"""Mock multimodal dataset.
This is unused at the moment and may be missing features. Follow-up changes will use this.
"""
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Return a sample that contains a dummy image, text sequence and the associated labels and cost and attention masks.
Args:
idx (int): The integer seed for mock data generation.
Returns:
Dict[str, torch.Tensor]: The mock data.
"""
# Get a text sample.
sample = super().__getitem__(idx)
# Add mock input image.
sample["image"] = torch.zeros(
(3, self.config.image_h, self.config.image_w), dtype=torch.float32
)
# Run optional data preprocessing.
preprocess_func = self.config.preprocess_func
return preprocess_func(sample)
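# ------------------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of Megatron Core): the kind of callable
# one might pass as `preprocess_func` in MultimodalDatasetConfig. It receives the
# sample dict produced above and must return a dict of tensors.
# ------------------------------------------------------------------------------------
from typing import Dict
import torch
def normalize_image(sample: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Scale the mock image to zero mean / unit variance (an assumed convention).
    image = sample["image"]
    sample["image"] = (image - image.mean()) / (image.std() + 1e-6)
    return sample
# A config would then be constructed with preprocess_func=normalize_image alongside
# image_h and image_w (plus the usual GPTDatasetConfig fields).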
# Data Pipeline
## Data pre-processing
Data preprocessing is built around the following classes:
1. `IndexedDatasetBuilder`
2. `IndexedDataset`
At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.
#### IndexedDatasetBuilder
The `IndexedDatasetBuilder` is capable of building and merging `IndexedDataset` instances.
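Since the end-to-end preprocessing pipeline is left to the user, the snippet below is a minimal sketch of writing and re-reading a `.bin`/`.idx` pair. It assumes the `IndexedDatasetBuilder` methods `add_item`, `end_document`, and `finalize` and the `IndexedDataset` constructor exposed by `megatron.core.datasets.indexed_dataset`; consult that module for the authoritative signatures.
```
import numpy
import torch
from megatron.core.datasets.indexed_dataset import IndexedDataset, IndexedDatasetBuilder

prefix = "/tmp/my_dataset"  # hypothetical path prefix
builder = IndexedDatasetBuilder(prefix + ".bin", dtype=numpy.int32)

documents = [[10, 11, 12, 13], [20, 21], [30, 31, 32]]  # toy token-id documents
for token_ids in documents:
    builder.add_item(torch.IntTensor(token_ids))  # write one sequence
    builder.end_document()                        # close the document
builder.finalize(prefix + ".idx")

# Read it back through the lowest-level data interface.
dataset = IndexedDataset(prefix)
print(len(dataset), dataset[0])
```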
#### IndexedDataset
The `IndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `IndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata.
The index file stores dataset-level metadata first:
- The index header, for backward compatibility
- The index version, for backward compatibility
- A numeric code corresponding to the data type used to write data to the data file
- The number of sequences in the dataset
- The number of documents in the dataset
The index file stores document-level and sequence-level metadata second:
- In order, the number of elements per sequence
- In order, the byte offset (pointer) per sequence
- In order, the consecutive sequence index range `[...)` per document
- In order, the mode per sequence (in the multimodal case)
## Data loading: construction
Building the data loaders is a distributed-aware process built around the following classes:
1. `BlendedMegatronDatasetConfig`
2. `BlendedMegatronDatasetBuilder`
3. `IndexedDataset`
4. `MegatronDataset`
5. `BlendedDataset`
See the class docstrings for more details.
#### BlendedMegatronDatasetConfig (extendable)
The `BlendedMegatronDatasetConfig` class parameterizes the `BlendedMegatronDatasetBuilder` and in turn the `MegatronDataset` and `BlendedDataset`.
Different training/inference regimes will require different extensions, e.g. the `GPTDatasetConfig`.
#### BlendedMegatronDatasetBuilder
The `BlendedMegatronDatasetBuilder` class builds the highest-level data interfaces in Megatron Core.
**NB:** All ranks should attempt to build the dataset via the `BlendedMegatronDatasetBuilder` or the program will hang. Which ranks follow through on their attempts can be controlled via the `BlendedMegatronDatasetConfig`.
#### IndexedDataset
The `IndexedDataset` class is the lowest-level data interface in Megatron Core.
The `IndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces.
#### MegatronDataset (extendable)
The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `IndexedDataset`.
Different training/inference regimes will require different extensions, e.g. the `GPTDataset`.
#### BlendedDataset
The `BlendedDataset` class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MegatronDataset`.
The `BlendedDataset` is only necessary when a blend of multiple data distributions, i.e. multiple `MegatronDataset` instances, should contribute to a certain dataset split. The blend can be controlled via the `BlendedMegatronDatasetConfig`.
## Data loading: implementation
### GPTDataset
The `GPTDataset` is parameterized by the following variables: the underlying `IndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the contiguous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`.
The `GPTDataset` creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.
1. The document index _Do_idx_ is a 1-D array mapping from _i_ to document index of length `E * |indexed_indices|` where `E` corresponds to the minimum number of epochs such that `E * |indexed_indices| >= N`. The document index is shuffled according to `R`.
```
Given:
N = 15
indexed_indices = [5, 6, 7, 8, 9]
E = 3
Then, for example:
Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
```
2. The sample index _Sa_idx_ is a 2-D array mapping from _j_ to pairs of (_i_, _Do_idx_[ _i_ ] offset) of shape `[N + 1, 2]`. The rows _j_ and _j_ + 1 serve as the left and right bounds for the _j_-th sample.
```
Given:
S = 1024
Then, for example:
Sa_idx[0] = (0, 0)
Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S
Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536
Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536
Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300
```
3. The shuffle index _Sh_idx_ is a 1-D array mapping from _k_ to _j_ of length `N`. The shuffle index is shuffled according to `R`.
```
Given
N = 10
Then, for example:
Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
```
To query the `GPTDataset` for the _k_-th sample we do the following
- Use the shuffle index to get the index _j_ into the sample index.
```
j = Sh_idx[k]
```
- Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.
```
i, offset = Sa_idx[j]
i_next, offset_next = Sa_idx[j + 1]
```
- Use the document index to retrieve `S` tokens from consecutive (in the document index) documents.
```
sample = []
if i == i_next:
    sample += indexed_dataset[Do_idx[i]][offset:offset_next]
else:
    sample += indexed_dataset[Do_idx[i]][offset:]
    sample += indexed_dataset[Do_idx[i + 1:i_next]]
    sample += indexed_dataset[Do_idx[i_next]][:offset_next]
```
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `MegatronDataset.__init__` function.
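Putting the three indices together, the lookup can be exercised end to end with plain NumPy. The following is a hedged, self-contained sketch over a toy corpus; it mirrors the logic of the walkthrough above, not the exact cached-index construction inside `GPTDataset`.
```
import numpy

rng = numpy.random.default_rng(1234)   # R
indexed_indices = [5, 6, 7, 8, 9]
N, S, E = 15, 8, 3                      # num samples, sequence length, epochs

# Toy corpus: document id -> token ids (20 tokens per document).
corpus = {d: numpy.arange(100 * d, 100 * d + 20) for d in range(5, 10)}

# (1) Document index: E epochs over the split, shuffled by R.
Do_idx = numpy.tile(indexed_indices, E)
rng.shuffle(Do_idx)

# (2) Sample index: (document position, token offset) pairs bounding each sample.
Sa_idx = [(0, 0)]
i, offset = 0, 0
for _ in range(N):
    remaining = S
    while remaining > 0:
        available = len(corpus[Do_idx[i]]) - offset
        if available > remaining:
            offset += remaining
            remaining = 0
        else:
            remaining -= available
            i, offset = i + 1, 0
    Sa_idx.append((i, offset))

# (3) Shuffle index over the N samples.
Sh_idx = rng.permutation(N)

# Retrieve the k-th sample exactly as in the walkthrough above.
k = 3
j = Sh_idx[k]
i, offset = Sa_idx[j]
i_next, offset_next = Sa_idx[j + 1]
if i == i_next:
    sample = list(corpus[Do_idx[i]][offset:offset_next])
else:
    sample = list(corpus[Do_idx[i]][offset:])
    for m in range(i + 1, i_next):
        sample += list(corpus[Do_idx[m]])
    sample += list(corpus[Do_idx[i_next]][:offset_next])
assert len(sample) == S
```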
### BlendedDataset
The `BlendedDataset` is parameterized by the following variables: the underlying `MegatronDataset` instances `D`, the weights `W` (one per dataset), and the size `S`. The `BlendedDataset` will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error.
The `BlendedDataset` creates two "blending" indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.
1. The dataset index _Da_idx_ is a 1-D array mapping from _i_ to dataset index of length `S`.
```
Given
D = [d0, d1, d2]
W = [1/2, 1/4, 1/4]
S = 4
Then, for example:
Da_idx = [0, 1, 2, 0]
```
2. The dataset sample index _Sa_idx_ is a 1-D mapping from _i_ to the sample index for dataset _Da_idx[i]_ of length `S`.
```
Given
Da_idx = [0, 1, 2, 0]
Then, for example:
Sa_idx = [0, 0, 0, 1]
```
To query the `BlendedDataset` for the _k_-th sample we do the following
- Use the dataset index to retrieve the corresponding dataset from `D` and the dataset sample index to retrieve the corresponding sample from that dataset.
```
sample = D[Da_idx[k]][Sa_idx[k]]
```
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function.
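For intuition, the greedy "greatest sampling error" rule can be sketched with plain NumPy as follows. This is a hedged illustration of one possible realization; the production index construction lives in a compiled helper and may differ in detail.
```
import numpy

W = numpy.array([0.5, 0.25, 0.25])  # normalized weights, one per dataset
S = 8                                # target blended size

dataset_index = numpy.zeros(S, dtype=numpy.int64)
dataset_sample_index = numpy.zeros(S, dtype=numpy.int64)
samples_per_dataset = numpy.zeros(len(W), dtype=numpy.int64)

for i in range(S):
    # Draw from the dataset whose current share lags its weight the most.
    errors = W * (i + 1) - samples_per_dataset
    d = int(numpy.argmax(errors))
    dataset_index[i] = d
    dataset_sample_index[i] = samples_per_dataset[d]
    samples_per_dataset[d] += 1

print(dataset_index)         # [0 1 2 0 0 1 2 0]
print(dataset_sample_index)  # [0 0 0 1 2 1 1 3]
```
With these toy values, the first four entries reproduce the `Da_idx` and `Sa_idx` examples shown above.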
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .config import RetroGPTChunkDatasets
from .query.multi_split_gpt_dataset import MultiSplitGPTDataset, MultiSplitGPTDatasetConfig
from .query.retro_dataset import get_retro_datasets
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- Embedder: Base class for all Bert embedders.
- RetroBertEmbedders: Container class for in-memory and on-disk embedders.
- RetroPreprocessingConfig: Configuration class for all of Retro preprocessing.
- RetroGPTChunkDatasets: Container class for train, valid, and test datasets.
- RetroTokenizers: Container class for GPT and Bert tokenizers.
"""
from .bert_embedders import Embedder, RetroBertEmbedders
from .config import RetroPreprocessingConfig
from .gpt_chunk_datasets import RetroGPTChunkDatasets
from .tokenizers import RetroTokenizers
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container dataclass for holding both in-memory and on-disk Bert embedders."""
import abc
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
class Embedder(abc.ABC):
"""Base class for all Bert embedders.
All embedders should be able to embed either an entire text dataset (to a 2D
numpy array), or a single text string (to a 1D numpy array).
"""
@abc.abstractmethod
def embed_text_dataset(self, text_dataset: torch.utils.data.Dataset) -> np.ndarray:
"""Embed a text dataset.
Args:
text_dataset (torch.utils.data.Dataset): Text dataset to embed. Each sample of the text dataset should output a dict with a key 'text' and a string value.
Returns:
A 2D ndarray with shape (len(text_dataset), dimension(embedder)).
"""
@abc.abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""Embed a simple string of text.
Args:
text (str): A single text sample.
Returns:
A 1D ndarray with shape (dimensions(embedder),).
"""
@dataclass
class RetroBertEmbedders:
"""Container dataclass for in-memory and on-disk Bert embedders."""
disk: Embedder
mem: Embedder
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro preprocessing config."""
from dataclasses import dataclass
from megatron.core.transformer import TransformerConfig
from .bert_embedders import RetroBertEmbedders
from .gpt_chunk_datasets import RetroGPTChunkDatasets
from .tokenizers import RetroTokenizers
@dataclass
class RetroPreprocessingConfig(TransformerConfig):
"""Configuration object for Retro preprocessing.
*Note* : Arguments prefixed with '--retro-gpt-*' or '--retro-bert-*' are
included and named as such to more easily manage both models running at the
same time. Megatron is not optimized to run two models at once, so this
naming convention makes the separation clearer.
Args:
retro_project_dir (str): Retro project directory, which contains the preprocessed data for pretraining. This directory is built during preprocessing (see tools/retro/README.md), and contains subdirectories for the chunk database and pretraining neighbors.
retro_tasks (str): Comma-separated list of tasks to run. Run the entire preprocessing pipeline by using '--retro-tasks build'. Alternatively, run individual stages with tasks (in this order) 'db-build', 'index-build', or 'query-pretraining-neighbors'. For example, '--retro-tasks db-build,index-build,query-pretraining-neighbors' is equivalent to '--retro-tasks build'; or the argument can contain a subset of these tasks. Stages must always be run in the correct order (listed above).
retro_task_validate (float): If defined, validate a randomly sampled subset of the existing results of the given task. Each task implements a 'validate' method that is responsible for sampling a `retro_task_validate` fraction of the existing results, and then checking for bitwise equality with the current code base. (E.g., `--retro-task-validate 0.01`.)
retro_block_size (int): Number of chunks to process at a time when generating Bert embeddings and querying the search index. Partial results for each block are generally saved to disk in separate files.
retro_doc_block_size (int): Number of documents to process at a time when processing token datasets into chunk databases. The partial chunk database for each block is saved into a separate file.
retro_gpt_seed (int): Random seed used for python, numpy, pytorch, and cuda.
retro_gpt_data_path (str): Path to the training dataset. Accepted format: 1) a single data path, 2) multiple datasets in the form: dataset1-weight dataset1-path dataset2-weight dataset2-path ... It is used with --split when a single dataset is used for all three: train, valid and test. It is exclusive to the other --*-data-path args.
retro_gpt_data_cache_path (str): Path to a directory to hold cached index files.
retro_gpt_split (str): Comma-separated list of proportions for training, validation, and test split. For example the split `90,5,5` will use 90%% of data for training, 5%% for validation and 5%% for test.
retro_gpt_train_samples (int): Total number of samples to train over all training runs.
retro_gpt_eval_interval (int): GPT evaluation interval.
retro_gpt_eval_iters (int): GPT evaluation iterations.
retro_gpt_tokenizer_type (str): GPT tokenizer type.
retro_gpt_tokenizer_model (str): GPT tokenizer model file.
retro_gpt_vocab_file (str): GPT vocab file.
retro_gpt_merge_file (str): GPT merge file.
retro_gpt_seq_length (int): GPT sequence length.
retro_gpt_global_batch_size (int): GPT global batch size.
retro_gpt_chunk_length (int): GPT chunk length.
retro_bert_tokenizer_type (str): Bert tokenizer type (for when using '--bert-embedder-type megatron').
retro_bert_vocab_file (str): Bert vocab file.
retro_bert_batch_size (int): Micro-batch size for processing Bert embeddings.
retro_bert_max_chunk_length (int): Maximum sequence length for Bert embeddings. (Named 'chunk' here in reference to these Bert sequences being converted from GPT chunks.)
retro_index_type (str): A 'faiss-base' index is a simple, un-optimized wrapper around a Faiss index. A 'faiss-par-add' index optimizes the 'add()' method by making it multi-node and multi-process, but with bit-wise equivalent results.
retro_index_str (str): Index string used for calling faiss.index_factory(). For example, 'IVF262144_HNSW32,Flat' or 'OPQ32_256,IVF4194304_HNSW32,PQ32'.
retro_index_ntrain (int): Number of database chunks to use for training the index. This value must be less than or equal to the total number of chunks in the database.
retro_index_train_load_fraction (float): Fraction of sampled chunks to use for training the index. Useful when our total sampled embeddings use too much memory; lowering the load fraction is less costly than re-embedding a new sampled dataset from scratch.
retro_index_add_load_fraction (float): Fraction of database chunks to use for adding to the index. Useful when our total index size would use too much memory; lowering the load fraction is less costly than re-designing our token datasets.
retro_index_delete_training_embeddings (bool): Delete training embeddings for the search index. Useful for debugging.
retro_index_delete_added_codes (bool): Delete added codes for the search index. Useful for debugging.
retro_query_ef_search (int): Index ef-search parameter for Hierarchical Navigable Small Worlds (HNSW) during querying.
retro_query_nprobe (int): Index nprobe parameter for Inverted File (IVF) during querying.
retro_query_num_neighbors_query (int): Number of neighbors to retrieve when calling index.search().
retro_query_num_neighbors_save (int): Number of neighbors to save to disk after querying the index. If the index returns more neighbors than this target, the list is truncated; if it returns fewer, the list is padded with -1's.
retro_bert_embedders (RetroBertEmbedders): Set of Bert embedders used for embedding chunks. Contains entries: 1) 'mem' for an in-memory embedder, and 2) 'disk' for an embedder that saves results in blocks to disk.
retro_gpt_chunk_datasets (RetroGPTChunkDatasets): GPT datasets for 'train', 'valid', and 'test'.
retro_tokenizers (RetroTokenizers): GPT ('gpt') and Bert ('bert') tokenizers.
"""
# Basic.
retro_project_dir: str = None
retro_tasks: str = 'build'
retro_task_validate: float = None
retro_block_size: int = 100000
retro_doc_block_size: int = 100000
# GPT.
retro_gpt_seed: int = 1234
retro_gpt_data_path: list = None # basic list here, for parsing purposes
retro_gpt_data_cache_path: str = None
retro_gpt_split: str = '969,30,1'
retro_gpt_train_samples: int = None
retro_gpt_eval_interval: int = None
retro_gpt_eval_iters: int = None
retro_gpt_tokenizer_type: str = None
retro_gpt_tokenizer_model: str = None
retro_gpt_vocab_file: str = None
retro_gpt_merge_file: str = None
retro_gpt_seq_length: int = None
retro_gpt_global_batch_size: int = None
retro_gpt_chunk_length: int = 64
# Bert.
retro_bert_tokenizer_type: str = None
retro_bert_vocab_file: str = None
retro_bert_batch_size: int = 128
retro_bert_max_chunk_length: int = 256
# Index.
retro_index_type: str = 'faiss-par-add'
retro_index_str: str = None
retro_index_ntrain: int = None
retro_index_train_load_fraction: float = 1.0
retro_index_add_load_fraction: float = 1.0
retro_index_delete_training_embeddings: bool = True
retro_index_delete_added_codes: bool = True
# Query.
retro_query_ef_search: int = 256
retro_query_nprobe: int = 65536
retro_query_num_neighbors_query: int = 200
retro_query_num_neighbors_save: int = 20
# Tools.
retro_bert_embedders: RetroBertEmbedders = None
retro_gpt_chunk_datasets: RetroGPTChunkDatasets = None
retro_tokenizers: RetroTokenizers = None
def __post_init__(self) -> None:
"""Validate Retro config."""
# Validate required attributes.
assert self.retro_project_dir is not None
assert self.retro_tasks is not None
assert self.retro_gpt_data_path is not None or self.retro_gpt_data_cache_path is not None
assert self.retro_gpt_train_samples is not None
assert self.retro_gpt_eval_interval is not None
assert self.retro_gpt_eval_iters is not None
assert self.retro_gpt_tokenizer_type is not None
assert self.retro_gpt_tokenizer_model is not None or (
self.retro_gpt_vocab_file is not None and self.retro_gpt_merge_file is not None
)
assert self.retro_gpt_seq_length is not None
assert self.retro_gpt_global_batch_size is not None
assert self.retro_bert_tokenizer_type is not None
assert self.retro_bert_vocab_file is not None
assert self.retro_index_str is not None
assert self.retro_index_ntrain is not None
# Split retro tasks.
self.retro_tasks = self.retro_tasks.split(",")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container dataclass for GPT chunk datasets (train, valid, and test)."""
from dataclasses import dataclass
@dataclass
class RetroGPTChunkDatasets:
"""Container dataclass for GPT chunk datasets."""
# Each dict contains 'dataset', 'neighbor_dir', and 'num_active_chunks'.
train: dict = None
valid: dict = None
test: dict = None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container class for GPT and Bert tokenizers."""
from dataclasses import dataclass
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
@dataclass
class RetroTokenizers:
"""Container class for GPT and Bert tokenizers."""
gpt: MegatronTokenizer = None
bert: MegatronTokenizer = None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- build_db: Build a chunk database from a list of indexed datasets.
"""
from .build import build_db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Build a chunk database from a list of indexed datasets.
Building a chunk database consists of:
- Breaking each document of each indexed dataset into consecutive
retro_gpt_chunk_length chunks.
- Re-tokenizing each chunk into Bert and discarding any chunks with empty Bert
tokens.
- Saving chunk offsets to disk for each indexed dataset.
"""
import glob
import os
import types
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import (
extract_data_config,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .utils import (
get_indexed_dataset_infos,
get_indexed_dataset_infos_path,
get_individual_chunk_db,
get_individual_db_dir,
get_individual_db_paths,
get_individual_doc_offsets,
get_merged_db_path_map,
init_indexed_dataset_infos,
load_indexed_datasets,
save_indexed_dataset_infos,
)
def build_partial_db(
config: types.SimpleNamespace,
dataset_idx: int,
n_datasets: int,
indexed_dataset: IndexedDataset,
block_id: int,
n_blocks: int,
block: dict,
proc_id: int,
n_procs: int,
) -> Tuple[int, list, list, dict]:
"""Process a document index range of the indexed dataset.
The chunk database is built in parallel blocks, since de-tokenizing &
re-tokenizing for Bert-length computation is expensive. This method
iterates each document and extracts sequential 'chunk-length' sequences
from each document.
Args:
config (types.SimpleNamespace): Subset of Retro config, containing 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'.
dataset_idx (int): Index of this dataset out of all blended datasets.
n_datasets (int): Total number of blended datasets.
indexed_dataset (IndexedDataset): Indexed dataset to be chunked.
block_id (int): Block index out of all blocks to be processed.
n_blocks (int): Total number of blocks to be processed.
block (dict): Range information such as start/end points for chunking indexed dataset.
proc_id (int): Process ID for tracking parallel process order.
n_procs (int): Total number of parallel processes.
Returns:
A tuple containing:
- Process ID.
- List of valid chunks.
- List of invalid chunks (i.e., chunks that converted to empty Bert embeddings).
- Dict mapping document ID to number of valid chunks.
"""
# Document start/end indexes.
doc_range = block["range"]
n_docs = doc_range[1] - doc_range[0]
n_docs_per_proc = int(np.ceil(n_docs / n_procs))
doc_start_id = doc_range[0] + proc_id * n_docs_per_proc
doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc)
# Print progress.
progress_proc_ids = set(range(n_procs)) if torch.distributed.get_rank() == 0 else set()
if proc_id in progress_proc_ids:
log_retro_rank_0(
" > building partial chunk db, proc %d / %d, docs %d:%d / %d."
% (proc_id, n_procs, doc_start_id, doc_end_id, n_docs)
)
# Progress bars (snapshot of overall progress).
doc_id_iter = range(doc_start_id, doc_end_id)
pbar = (
tqdm(doc_id_iter, "parse doc chunks", miniters=len(doc_id_iter) // 20)
if proc_id in progress_proc_ids
else doc_id_iter
)
# Iterate documents & parse chunks.
chunk_db_valid: List[Tuple] = []
chunk_db_invalid: List[Tuple] = []
doc_size_map = {}
for doc_id in pbar:
# Progress description.
try:
pbar.set_description(
"%sds %d / %d, block %d / %d, proc %d / %d."
% (
"" if config.task_validate is None else "[validate] ",
dataset_idx,
n_datasets,
block_id,
n_blocks,
proc_id,
n_procs,
)
)
except Exception:
pass
# Remove EOD token.
doc = indexed_dataset.get(doc_id)
if doc[-1].item() == config.gpt_eod:
doc = doc[:-1]
doc_len = len(doc)
# Chunk start/end indexes.
chunk_start_idxs = list(range(0, doc_len, config.chunk_length))
chunk_end_idxs = [min(doc_len, s + config.chunk_length) for s in chunk_start_idxs]
# Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
doc_size_map[doc_id] = 0
for i, chunk_start_idx in enumerate(chunk_start_idxs):
# Re-tokenize.
chunk_end_idx = chunk_end_idxs[i]
gpt_token_ids = indexed_dataset.get(
idx=doc_id, offset=chunk_start_idx, length=chunk_end_idx - chunk_start_idx
)
text = config.gpt_detokenize(gpt_token_ids.tolist())
bert_token_ids = config.bert_tokenize(text)
# 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
if len(bert_token_ids) == 0:
_chunk_db = chunk_db_invalid
else:
_chunk_db = chunk_db_valid
doc_size_map[doc_id] += 1
_chunk_db.append((doc_id, chunk_start_idx, chunk_end_idx, len(bert_token_ids)))
return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map
def build_block_db(
config: RetroPreprocessingConfig,
dataset_idx: int,
n_datasets: int,
indexed_dataset: IndexedDataset,
n_procs: int,
executor: ProcessPoolExecutor,
n_missing_blocks: int,
block_idx: int,
block: dict,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Split each document within block into consecutive retro_gpt_chunk_length size chunks.
Args:
config (RetroPreprocessingConfig): For DB building, we make use of attributes 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'.
dataset_idx (int): Index of this dataset out of all blended datasets.
n_datasets (int): Total number of blended datasets.
indexed_dataset (IndexedDataset): Indexed dataset to be chunked.
n_procs (int): Total number of parallel processes.
executor (ProcessPoolExecutor): Executor for launching parallel processes.
n_missing_blocks (int): Total number of blocks to be processed.
block_idx (int): Block index out of all blocks to be processed.
block (dict): Range information such as start/end points for chunking indexed dataset.
Returns:
A tuple containing:
- Array of valid chunks.
- Array of invalid chunks (i.e., chunks that converted to empty Bert embeddings).
- Array of document offsets (document ID paired with the cumulative number of valid chunks).
"""
# Build partial dbs.
log_retro_rank_0(' > build partial dbs.')
futures = []
for proc_id in range(n_procs): # not true process id
futures.append(
executor.submit(
build_partial_db,
types.SimpleNamespace(
chunk_length=config.retro_gpt_chunk_length,
gpt_eod=config.retro_tokenizers.gpt.eod,
gpt_detokenize=config.retro_tokenizers.gpt.detokenize,
bert_tokenize=config.retro_tokenizers.bert.tokenize,
task_validate=config.retro_task_validate,
),
dataset_idx,
n_datasets,
indexed_dataset,
block_idx,
n_missing_blocks,
block,
proc_id,
n_procs,
)
)
partial_chunk_dbs = []
for future in as_completed(futures):
partial_chunk_dbs.append(future.result())
# Concatenate chunks.
partial_chunk_dbs.sort(key=lambda item: item[0]) # sort by proc_id
chunk_db_valid = [
item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[1]
]
chunk_db_invalid = [
item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[2]
]
# Convert to numpy.
log_retro_rank_0(' > converting chunk db to numpy.')
chunk_db_valid = np.array(chunk_db_valid, dtype="uint32")
chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32")
# Document offsets.
doc_sizes = [
(d, s) for partial_chunk_db in partial_chunk_dbs for d, s in partial_chunk_db[3].items()
]
doc_sizes.sort(key=lambda item: item[0])
doc_offsets = np.cumsum([item[1] for item in doc_sizes]).astype("uint64")
doc_offsets = np.stack(
(np.array([item[0] for item in doc_sizes], dtype="uint64"), doc_offsets), axis=1
)
return chunk_db_valid, chunk_db_invalid, doc_offsets
def save_block_db(
block: dict, chunk_db_valid: np.ndarray, chunk_db_invalid: np.ndarray, doc_offsets: np.ndarray
) -> None:
"""Save block of chunked tokens to disk. These blocks are later used for
training and adding to the vector index.
Args:
block (dict): Range information such as start/end points for chunking indexed dataset.
chunk_db_valid (np.ndarray): Array of valid chunk indexes.
chunk_db_invalid (np.ndarray): Array of invalid chunk indexes.
doc_offsets (np.ndarray): Array of document offsets by chunks.
"""
log_retro_rank_0(" > saving individual db.")
with h5py.File(block["path"], "w") as f:
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
dset = f.create_dataset("doc_offsets", data=doc_offsets)
def build_individual_db(
config: RetroPreprocessingConfig, dataset_idx: int, n_datasets: int, dataset_info: dict
) -> None:
"""Process a single indexed dataset & extract chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
dataset_idx (int): Dataset index within blended dataset.
n_datasets (int): Total number of datasets within blended dataset.
dataset_info (dict): Metadata for dataset (see `save_indexed_dataset_infos()` in `utils.py` for more detail).
"""
# Make directory.
db_dir = get_individual_db_dir(config.retro_project_dir, dataset_info["prefix"])
retro_makedir(config, db_dir)
# Indexed dataset.
indexed_dataset = dataset_info["dataset"]
# Missing DB blocks (split by documents).
blocks = get_blocks_by_rank(
db_dir,
len(indexed_dataset),
config.retro_doc_block_size,
validate=lambda f: f["chunks_valid"].shape == (0,) or f["chunks_valid"].shape[1] == 4,
sample=config.retro_task_validate,
)
if config.retro_task_validate is None:
active_blocks = blocks.missing
else:
assert blocks.n_missing_world == 0
active_blocks = blocks.existing
# Prevent missing-path-write race condition.
torch.distributed.barrier()
# Nothing to do?
if config.retro_task_validate is None and not active_blocks:
return
# Num processes.
if blocks.n_missing_world == 1:
n_procs = 128
elif blocks.n_missing_world <= 2:
n_procs = 64
elif blocks.n_missing_world <= 4:
n_procs = 32
elif blocks.n_missing_world <= 8:
n_procs = 16
else:
n_procs = 8
# Process documents in parallel.
with ProcessPoolExecutor(max_workers=n_procs) as executor:
for block_idx, block in enumerate(active_blocks):
if block is not None:
# Build block DB.
chunk_db_valid, chunk_db_invalid, doc_offsets = build_block_db(
config=config,
dataset_idx=dataset_idx,
n_datasets=n_datasets,
indexed_dataset=indexed_dataset,
n_procs=n_procs,
executor=executor,
n_missing_blocks=len(active_blocks),
block_idx=block_idx,
block=block,
)
if config.retro_task_validate is None:
# Save block DB.
save_block_db(
block=block,
chunk_db_valid=chunk_db_valid,
chunk_db_invalid=chunk_db_invalid,
doc_offsets=doc_offsets,
)
else:
# Load existing block DB.
with h5py.File(block["path"]) as f:
existing_chunks_valid = np.copy(f["chunks_valid"])
existing_chunks_invalid = np.copy(f["chunks_invalid"])
existing_doc_offsets = np.copy(f["doc_offsets"])
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_chunks_valid, chunk_db_valid)
assert np.array_equal(existing_chunks_invalid, chunk_db_invalid)
assert np.array_equal(existing_doc_offsets, doc_offsets)
# Wait for all ranks to finish block.
log_retro_rank_0(" > waiting for all ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished saving individual db.")
def build_individual_dbs(
config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict]
) -> None:
"""Iterate each indexed dataset & process its chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset.
"""
# Build individual DBs.
log_retro_rank_0(" > build individual chunk dbs.")
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
# Progress.
log_retro_rank_0(
" > building individual db, dataset %d / %d ... '%s'."
% (ds_idx, len(indexed_dataset_infos), ds_info["prefix"])
)
# Process single dataset.
build_individual_db(config, ds_idx, len(indexed_dataset_infos), ds_info)
def update_chunk_counts(
config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict]
) -> None:
"""Set n_chunks_train & n_chunks sampled for each individual DB.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
"""
if torch.distributed.get_rank() != 0:
return
# Data ratio sum (for setting index training chunks).
data_ratio_sum = sum([d["ratio"] for d in indexed_dataset_infos])
# Training split size (split at document level).
train_fraction = float(extract_data_config(config).split.split(",")[0]) / 100
assert train_fraction > 0 and train_fraction <= 1
# Set n_chunks (including n_chunks_sampled for unambiguity).
log_retro_rank_0(" > compute n_chunks.")
for ds_index, ds_info in enumerate(indexed_dataset_infos):
db_paths = get_individual_db_paths(config.retro_project_dir, ds_info["prefix"])
# Update counts.
ds_info["n_docs"] = len(ds_info["dataset"].document_indices) - 1
ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"])
ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
ds_info["n_chunks_train"] = 0
ds_info["n_chunks_invalid"] = 0
for db_path in tqdm(
db_paths, "%d/%d, %s" % (ds_index, len(indexed_dataset_infos), ds_info["prefix"])
):
with h5py.File(db_path, "r") as f:
ds_info["n_chunks"] += len(f["chunks_valid"])
ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
ds_info["n_chunks_train"] += (
(np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]).sum().item()
)
ds_info["n_chunks_sampled"] = int(
config.retro_index_ntrain * ds_info["ratio"] / data_ratio_sum
)
# Verify counts.
assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], "n_train (%d) > n_total (%d)." % (
ds_info["n_chunks_train"],
ds_info["n_chunks"],
)
assert (
ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"]
), "n_sampled (%d) > n_train (%d)." % (
ds_info["n_chunks_sampled"],
ds_info["n_chunks_train"],
)
def merge_dbs(project_dir: str, indexed_dataset_infos: List[Dict], db_type: str) -> None:
"""Merge individual DBs into single DB.
Args:
project_dir (str): Retro project dir.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
db_type (str): DB type (e.g., 'sampled', 'train', or 'valid').
"""
if torch.distributed.get_rank() != 0:
return
log_retro_rank_0(" > build %s chunk db." % db_type)
# Count chunks.
if db_type == "sampled":
n_chunks_key = "n_chunks_sampled"
n_docs_key = None
elif db_type == "train":
n_chunks_key = "n_chunks_train"
n_docs_key = "n_docs_train"
elif db_type == "valid":
n_docs_key = None
else:
raise Exception("handle db_type '%s'." % db_type)
if db_type == "valid":
n_chunks = sum(m["n_chunks"] - m["n_chunks_train"] for m in indexed_dataset_infos)
else:
n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
n_docs = None if n_docs_key is None else sum(m[n_docs_key] for m in indexed_dataset_infos)
# DB path.
db_path = get_merged_db_path_map(project_dir)[db_type]
# Delete existing chunk db if incorrect size.
if os.path.exists(db_path):
try:
f = h5py.File(db_path)
n_alloc = len(f["chunks"]) # total allocated
n_written = f["n_written"][0].item() # total written
f.close()
if n_chunks != n_alloc or n_chunks != n_written:
os.remove(db_path)
except Exception as e:
if isinstance(e, OSError):
os.remove(db_path)
elif isinstance(e, KeyError):
f.close()
os.remove(db_path)
else:
raise e
# Build merged chunk db.
if not os.path.exists(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
f = h5py.File(db_path, "w")
# Initialize output arrays.
merged_chunk_db: np.ndarray = f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
merged_doc_offsets: np.ndarray = (
None
if n_docs_key is None
else f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
)
n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0
# Iterate indexed datasets & collect chunks.
chunk_start_index = 0
doc_start_index = 0
doc_start_offset = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
log_retro_rank_0(
" > merging dbs; '%s', dataset %d / %d ... '%s'."
% (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"])
)
individual_chunk_db: np.ndarray = get_individual_chunk_db(project_dir, ds_idx, ds_info)
individual_doc_offsets: np.ndarray = (
None
if n_docs_key is None
else get_individual_doc_offsets(project_dir, ds_idx, ds_info)
)
if db_type == "valid":
individual_chunk_db = individual_chunk_db[ds_info["n_chunks_train"] :]
if n_docs_key is None:
individual_doc_offsets = None
else:
train_doc_offset = individual_doc_offsets[ds_info["n_docs_train"] - 1, 2]
individual_doc_offsets = np.copy(
individual_doc_offsets[ds_info["n_docs_train"] :]
)
individual_doc_offsets[:, 2] -= train_doc_offset
log_retro_rank_0("~~~")
log_retro_rank_0(individual_doc_offsets)
log_retro_rank_0(train_doc_offset)
raise Exception("test me.")
else:
individual_chunk_db = individual_chunk_db[: ds_info[n_chunks_key]]
individual_doc_offsets = (
None
if n_docs_key is None
else np.copy(individual_doc_offsets[: ds_info[n_docs_key]])
)
merged_chunk_db[chunk_start_index : chunk_start_index + len(individual_chunk_db)] = (
individual_chunk_db
)
chunk_start_index += len(individual_chunk_db)
n_written[0] = chunk_start_index
if n_docs_key is not None:
individual_doc_offsets[:, 2] += doc_start_offset
doc_end_index = doc_start_index + individual_doc_offsets.shape[0]
merged_doc_offsets[doc_start_index:doc_end_index] = individual_doc_offsets
doc_start_index = doc_end_index
doc_start_offset = individual_doc_offsets[-1, 2].item()
f.close()
def build_merged_dbs(project_dir: str, indexed_dataset_infos: List[Dict]) -> None:
"""Merge individual dataset components into single database.
This method merges databases for DB types:
- 'sampled': used for training the vector index.
- 'train': used for adding to the trained vector index.
- 'valid': can be used for validating/testing the vector index.
Args:
project_dir (str): Retro project dir.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
"""
merge_dbs(project_dir, indexed_dataset_infos, "sampled")
merge_dbs(project_dir, indexed_dataset_infos, "train")
merge_dbs(project_dir, indexed_dataset_infos, "valid")
def build_db(config: RetroPreprocessingConfig) -> None:
"""Extract token chunks from each indexed dataset.
Iterate each document of each indexed dataset, extract that document's chunks, and save to a 'DB' (hdf5 file).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
project_dir = config.retro_project_dir
# Indexed dataset info.
if config.retro_task_validate is None:
indexed_dataset_infos = init_indexed_dataset_infos(config)
else:
indexed_dataset_infos = get_indexed_dataset_infos(config.retro_project_dir)
# Build individual dbs.
build_individual_dbs(config, indexed_dataset_infos)
# If validating, return here.
if config.retro_task_validate is not None:
return
# Single-process going forward.
if torch.distributed.get_rank() != 0:
return
# Update n_chunks & save indexed dataset infos.
if not os.path.exists(get_indexed_dataset_infos_path(project_dir)):
update_chunk_counts(config, indexed_dataset_infos)
save_indexed_dataset_infos(project_dir, indexed_dataset_infos)
indexed_dataset_infos = get_indexed_dataset_infos(project_dir)
# Build merged dbs.
build_merged_dbs(project_dir, indexed_dataset_infos)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""A DBDataset is for iterating the chunks of the chunk database.
This dataset is used for both training a vector index, and adding vectors to a
trained index.
"""
from typing import List
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.indexed_dataset import IndexedDataset
class DBDataset(torch.utils.data.Dataset):
"""Dataset for iterating chunks.
Args:
db_path (str): Path of HDF5-format chunk database.
indexed_datasets (List[IndexedDataset]): Indexed datasets used to build database.
chunks (np.ndarray): Array of chunk indexes, for indexing into indexed datasets. Format [dataset_idx, doc_id, start_idx, end_idx, bert_length].
chunk_length (int): Max GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
"""
def __init__(
self,
db_path: str,
indexed_datasets: List[IndexedDataset],
chunks: np.ndarray,
chunk_length: int,
eod_token_id: int,
):
assert chunks.shape[1] == 5, (
"expected 5 columns (dataset_idx, "
"doc_idx, token_start_idx, token_end_idx, bert_chunk_length); "
"found %d columns." % chunks.shape[1]
)
self.db_path = db_path
self.indexed_datasets = indexed_datasets
self.chunks = chunks
self.doc_chunk_map = None
self.max_chunk_length = chunk_length
self.eod_token_id = eod_token_id
def __len__(self) -> int:
"""Length of DB dataset.
Returns:
Number of chunks contained in the dataset.
"""
return self.chunks.shape[0]
def __getitem__(self, chunk_id: int) -> dict:
"""DB dataset sample.
Args:
chunk_id (int): Index of chunk within dataset.
Returns:
A dict containing:
- 'doc_id': Document index within indexed dataset.
- 'text': GPT token IDs.
"""
# Chunk start/end indexes.
indexed_dataset_id, doc_id, token_start_idx, token_end_idx, _ = [
value.item() for value in self.chunks[chunk_id]
]
chunk_length = token_end_idx - token_start_idx
indexed_dataset = self.indexed_datasets[indexed_dataset_id]
# Chunk token ids.
token_ids = indexed_dataset.get(doc_id, offset=token_start_idx, length=chunk_length)
# Extend chunks to max_chunk_length by padding with EOD tokens.
if chunk_length != self.max_chunk_length:
assert chunk_length < self.max_chunk_length, "invalid chunk len."
token_ids = token_ids.tolist()
token_ids += [self.eod_token_id] * (self.max_chunk_length - chunk_length)
return {"doc_id": doc_id, "text": np.array(token_ids, dtype=np.int64)}
def load_doc_tuples(self) -> None:
"""Load the dataset & document ids.
Load the dataset id & document id of each chunk in the database, to
be used for causality filtering during querying.
"""
self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32")
block_size = int(1e6)
for start_idx in tqdm(
range(0, len(self), block_size),
"load doc tuples",
miniters=(len(self) // block_size) // 10,
disable=torch.distributed.get_rank() != 0,
):
end_idx = min(len(self), start_idx + block_size)
self.doc_tuples[start_idx:end_idx] = self.chunks[start_idx:end_idx, :2]
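# ------------------------------------------------------------------------------------
# Illustrative sketch (toy numbers only): the padding rule applied in __getitem__
# above. A chunk shorter than the max chunk length is right-padded with the EOD id.
# ------------------------------------------------------------------------------------
import numpy as np
max_chunk_length, eod_token_id = 64, 0  # assumed values for the example
token_ids = list(range(1, 51))  # a 50-token chunk read from the indexed dataset
token_ids += [eod_token_id] * (max_chunk_length - len(token_ids))
assert len(token_ids) == max_chunk_length
sample = {"doc_id": 123, "text": np.array(token_ids, dtype=np.int64)}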