# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy

from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
    MaskedWordPieceDataset,
    MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split


@dataclass
class BERTMaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
    """Configuration object for Megatron Core BERT WordPiece datasets"""

    classification_head: bool = None
    """Option to perform the next sequence prediction during sampling"""

    def __post_init__(self) -> None:
        """Do asserts and set fields post init"""
        super().__post_init__()

        assert self.classification_head is not None


class BERTMaskedWordPieceDataset(MaskedWordPieceDataset):
    """The BERT dataset that assumes WordPiece tokenization

    Args:
        indexed_dataset (IndexedDataset): The IndexedDataset around which to build the
            MegatronDataset

        dataset_path (str): The real path on disk to the dataset, for bookkeeping

        indexed_indices (numpy.ndarray): The set of the document indices to expose

        num_samples (Optional[int]): The number of samples to draw from the indexed dataset.
            When None, build as many samples as correspond to one epoch.

        index_split (Split): The indexed_indices Split

        config (BERTMaskedWordPieceDatasetConfig): The config
    """

    def __init__(
        self,
        indexed_dataset: IndexedDataset,
        dataset_path: str,
        indexed_indices: numpy.ndarray,
        num_samples: Optional[int],
        index_split: Split,
        config: BERTMaskedWordPieceDatasetConfig,
    ) -> None:
        super().__init__(
            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
        )

        self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())

        # Account for the single [CLS] and two [SEP] token ids
        self.sample_index = self._build_sample_index(
            self.config.sequence_length - 3, 2 if self.config.classification_head else 1
        )

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Inherited method implementation

        Returns:
            List[str]: The key config attributes
        """
        return super(
            BERTMaskedWordPieceDataset, BERTMaskedWordPieceDataset
        )._key_config_attributes() + ["classification_head"]

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        """Abstract method implementation

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, Union[int, numpy.ndarray]]: The token ids, token type ids, labels, and
                masks wrapped in a dictionary
        """
        idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
        sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
        numpy_random_state = numpy.random.RandomState(
            seed=(self.config.random_seed + idx) % 2**32
        )

        assert target_sequence_length <= self.config.sequence_length

        # Split the sample into contiguous subsegments A and B
        pivot = len(sample)
        is_next_random = False
        if self.config.classification_head:
            assert len(sample) > 1, "the sample must contain at least two sentences"
            pivot = 1
            if len(sample) >= 3:
                pivot = numpy_random_state.randint(low=1, high=len(sample))
            is_next_random = numpy_random_state.random() < 0.5
        split_A = []
        for sample_a in sample[:pivot]:
            split_A.extend(sample_a)
        split_B = []
        for sample_b in sample[pivot:]:
            split_B.extend(sample_b)
        if is_next_random:
            split_A, split_B = split_B, split_A

        # Trim the subsegments from either end to a desired joint length
        length_A = len(split_A)
        length_B = len(split_B)
        if length_A + length_B <= target_sequence_length:
            truncated = False
        else:
            while length_A + length_B > target_sequence_length:
                split = split_A if length_A > length_B else split_B
                if numpy_random_state.random() < 0.5:
                    del split[0]
                else:
                    del split[-1]
                length_A = len(split_A)
                length_B = len(split_B)
            truncated = True

        # Merge the subsegments and create the token assignment labels
        tokens = [self.config.tokenizer.cls, *split_A, self.config.tokenizer.sep]
        assignments = [0 for _ in range(1 + len(split_A) + 1)]
        if split_B:
            tokens += [*split_B, self.config.tokenizer.sep]
            assignments += [1 for _ in range(len(split_B) + 1)]

        # Masking
        tokens, masked_positions, masked_labels, _, _ = self._create_masked_lm_predictions(
            tokens, target_sequence_length, numpy_random_state
        )

        # Pad the sequences and convert to NumPy
        length_toks = len(tokens)
        length_pads = self.config.sequence_length - length_toks
        assert length_pads >= 0

        tokens = numpy.array(tokens, dtype=numpy.int64)
        tokens = numpy.pad(tokens, (0, length_pads), constant_values=self.config.tokenizer.pad)

        assignments = numpy.array(assignments, dtype=numpy.int64)
        assignments = numpy.pad(
            assignments, (0, length_pads), constant_values=self.config.tokenizer.pad
        )

        # Get the padding mask
        mask_pads = numpy.ones(length_toks, dtype=numpy.int64)
        mask_pads = numpy.pad(
            mask_pads, (0, length_pads), constant_values=self.config.tokenizer.pad
        )

        # Mask the labels
        labels = numpy.zeros(self.config.sequence_length, dtype=numpy.int64) - 1
        labels[masked_positions] = masked_labels

        # Get the loss mask
        mask_loss = numpy.zeros(self.config.sequence_length, dtype=numpy.int64)
        mask_loss[masked_positions] = 1

        return {
            "text": tokens,
            "types": assignments,
            "labels": labels,
            "is_random": int(is_next_random),
            "padding_mask": mask_pads,
            "loss_mask": mask_loss,
            "truncated": int(truncated),
        }

    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
        """Abstract method implementation

        80% of the time, replace the token id with the mask token id. 10% of the time, replace
        the token id with a random token id from the vocabulary. 10% of the time, do nothing.

        Args:
            numpy_random_state (RandomState): The NumPy random state

        Returns:
            Optional[int]: The replacement token id or None
        """
        if numpy_random_state.random() < 0.8:
            return self.config.tokenizer.mask
        else:
            if numpy_random_state.random() >= 0.5:
                return self.token_lookup[numpy_random_state.randint(0, len(self.token_lookup))]
        return None
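

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): the helper
# below shows one way the config and dataset defined above might be wired
# together by hand. The path prefix, tokenizer object, and any additional
# keyword arguments required by the parent MaskedWordPieceDatasetConfig
# (e.g. masking probabilities) are assumptions supplied by the caller; in
# practice these datasets are normally constructed through Megatron's
# dataset builder rather than directly.
# ---------------------------------------------------------------------------
def _example_build_bert_dataset(
    path_prefix: str, tokenizer, sequence_length: int = 512, **extra_config_kwargs
) -> BERTMaskedWordPieceDataset:
    """Sketch: build a BERTMaskedWordPieceDataset over every document of an IndexedDataset."""
    # Memory-map the binary dataset produced by Megatron's preprocessing tools
    indexed_dataset = IndexedDataset(path_prefix)

    # Enable the classification head so __getitem__ performs the A/B split,
    # the random swap, and the "is_random" labeling for next sentence prediction
    config = BERTMaskedWordPieceDatasetConfig(
        random_seed=1234,
        sequence_length=sequence_length,
        tokenizer=tokenizer,  # assumed: exposes cls/sep/mask/pad ids and inv_vocab
        classification_head=True,
        **extra_config_kwargs,  # assumed: any remaining required masking parameters
    )

    # Expose all documents; num_samples=None builds one epoch worth of samples
    return BERTMaskedWordPieceDataset(
        indexed_dataset=indexed_dataset,
        dataset_path=path_prefix,
        indexed_indices=numpy.arange(len(indexed_dataset), dtype=numpy.int32),
        num_samples=None,
        index_split=Split.train,
        config=config,
    )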