# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import numpy

from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
    MaskedWordPieceDataset,
    MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split


@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
    """Configuration object for Megatron Core T5 WordPiece datasets

    NB: As a temporary holdover from Megatron-LM, the T5 tokenizer has an attribute which
    defines a number of special sentinel tokens used during sampling. The assert in
    __post_init__ serves to preserve compatibility with Megatron-LM until the T5 tokenizer
    is in Megatron Core.
    """

    sequence_length_encoder: Optional[int] = field(init=False, default=None)
    """A sequence_length alias and the sequence length for the encoder"""

    sequence_length_decoder: Optional[int] = None
    """The sequence length for the decoder"""

    def __post_init__(self) -> None:
        """Do asserts and set fields post init"""
        super().__post_init__()

        self.sequence_length_encoder = self.sequence_length

        assert self.sequence_length_encoder is not None
        assert self.sequence_length_decoder is not None

        assert len(self.tokenizer.additional_special_tokens_ids) > 0


class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
    """The T5 dataset that assumes WordPiece tokenization

    Args:
        indexed_dataset (IndexedDataset): The IndexedDataset around which to build the
            MegatronDataset

        dataset_path (str): The real path on disk to the dataset, for bookkeeping

        indexed_indices (numpy.ndarray): The set of document indices to expose

        num_samples (Optional[int]): The number of samples to draw from the indexed dataset.
            When None, build as many samples as correspond to one epoch.
        index_split (Split): The indexed_indices Split

        config (T5MaskedWordPieceDatasetConfig): The config
    """

    def __init__(
        self,
        indexed_dataset: IndexedDataset,
        dataset_path: str,
        indexed_indices: numpy.ndarray,
        num_samples: Optional[int],
        index_split: Split,
        config: T5MaskedWordPieceDatasetConfig,
    ) -> None:
        super().__init__(
            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
        )

        self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
        # Account for the single <cls> and single <sep> token ids
        self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Inherited method implementation

        Returns:
            List[str]: The key config attributes
        """
        return super(
            T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
        )._key_config_attributes() + ["sequence_length_decoder"]

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        """Abstract method implementation

        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, Union[int, numpy.ndarray]]: The token ids and masks, wrapped in a
                dictionary
        """
        idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
        sample = [self.dataset[i] for i in range(idx_beg, idx_end)]

        numpy_random_state = numpy.random.RandomState(
            seed=(self.config.random_seed + idx) % 2 ** 32
        )

        assert target_sequence_length <= self.config.sequence_length

        # Flatten the sample into a list of tokens
        tokens = [token for sentence in sample for token in sentence]

        # Truncate the list of tokens to a desired length
        truncated = len(tokens) > target_sequence_length
        tokens = tokens[:target_sequence_length]

        # Masking
        (tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions(
            tokens, target_sequence_length, numpy_random_state
        )

        # Prepare the encoder input and decoder input and output
        sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
        encoder_input = []
        decoder_input = [self.config.tokenizer.bos]
        decoder_output = []
        idx_beg = 0
        for indices, labels in masked_spans:
            sentinel = sentinels.popleft()

            # set the end index
            idx_end = indices[0]

            encoder_input.extend(tokens[idx_beg:idx_end])
            encoder_input.append(sentinel)

            decoder_input.append(sentinel)
            decoder_input.extend(labels)

            decoder_output.append(sentinel)
            decoder_output.extend(labels)

            # set the start index
            idx_beg = indices[-1] + 1

        encoder_input.extend(tokens[idx_beg:])
        decoder_output.append(self.config.tokenizer.eos)

        # Pad the sequences and convert to NumPy
        length_toks_encoder = len(encoder_input)
        length_toks_decoder = len(decoder_input)
        length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
        length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
        assert length_pads_encoder >= 0
        assert length_pads_decoder >= 0

        encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
        encoder_input = numpy.pad(
            encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
        )

        decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
        decoder_input = numpy.pad(
            decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
        )

        # Create attention and history masks
        mask_encoder = self._make_attention_mask(encoder_input, encoder_input)
        mask_encoder_decoder = self._make_attention_mask(decoder_input, encoder_input)
        mask_decoder = self._make_attention_mask(decoder_input, decoder_input)
        mask_decoder = mask_decoder * self._make_history_mask(decoder_input)

        # Mask the labels
        decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
        decoder_output = numpy.pad(
            decoder_output, (0, length_pads_decoder), constant_values=-1
        )
        # Get the loss mask
        loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
        loss_mask[:length_toks_decoder] = 1

        return {
            "text_enc": encoder_input,
            "text_dec": decoder_input,
            "labels": decoder_output,
            "loss_mask": loss_mask,
            "truncated": int(truncated),
            "enc_mask": mask_encoder,
            "dec_mask": mask_decoder,
            "enc_dec_mask": mask_encoder_decoder,
        }

    @staticmethod
    def _make_attention_mask(
        source_block: numpy.ndarray, target_block: numpy.ndarray
    ) -> numpy.ndarray:
        """Return a 2-D attention mask

        Args:
            source_block (numpy.ndarray): A 1-D array

            target_block (numpy.ndarray): A 1-D array

        Returns:
            numpy.ndarray: The 2-D attention mask
        """
        mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
        return mask.astype(numpy.int64)

    @staticmethod
    def _make_history_mask(block: numpy.ndarray) -> numpy.ndarray:
        """Return a 2-D history (lower-left-triangular) mask

        Args:
            block (numpy.ndarray): A 1-D array

        Returns:
            numpy.ndarray: The 2-D history (lower-left-triangular) mask
        """
        arange = numpy.arange(block.shape[0])
        mask = arange[None,] <= arange[:, None]
        return mask.astype(numpy.int64)

    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
        """Abstract method implementation

        100% of the time, replace the token id with the mask token id.

        Args:
            numpy_random_state (RandomState): The NumPy random state

        Returns:
            int: The mask token id
        """
        return self.config.tokenizer.mask
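

if __name__ == "__main__":
    # Illustrative sketch only, not part of the dataset API: a minimal demonstration of
    # the mask semantics produced by the static helpers above. The token ids below are
    # hand-written toy values, and a pad id of 0 is assumed (any id below 1 is treated
    # as padding by the `>= 1` comparison in _make_attention_mask).
    toy_encoder = numpy.array([101, 7, 8, 9, 102, 0, 0], dtype=numpy.int64)
    toy_decoder = numpy.array([101, 7, 8, 0, 0], dtype=numpy.int64)

    # Encoder self-attention mask: 1 wherever both positions hold real (non-pad) tokens
    enc_mask = T5MaskedWordPieceDataset._make_attention_mask(toy_encoder, toy_encoder)

    # Cross-attention mask: decoder positions may attend to non-pad encoder positions
    enc_dec_mask = T5MaskedWordPieceDataset._make_attention_mask(toy_decoder, toy_encoder)

    # Decoder self-attention mask: the padding mask multiplied by the lower-triangular
    # history mask, yielding a causal mask over the non-pad decoder positions
    dec_mask = T5MaskedWordPieceDataset._make_attention_mask(
        toy_decoder, toy_decoder
    ) * T5MaskedWordPieceDataset._make_history_mask(toy_decoder)

    print(enc_mask)
    print(enc_dec_mask)
    print(dec_mask)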