# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
import os
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy
import torch

from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset
from megatron.core.datasets.utils import Split
from megatron.core.utils import log_single_rank

logger = logging.getLogger(__name__)


@dataclass
class MaskedWordPieceDatasetConfig(BlendedMegatronDatasetConfig):
    """Configuration object for Megatron Core Masked WordPiece datasets"""

    masking_probability: float = None
    """The probability we mask a candidate N-gram"""

    short_sequence_probability: float = None
    """The probability we return a sequence shorter than the target sequence length"""

    masking_max_ngram: int = None
    """The maximum length N-gram to consider masking or permuting"""

    masking_do_full_word: bool = None
    """Whether we mask the whole word or its component parts"""

    masking_do_permutation: bool = None
    """Whether we shuffle a subset of candidate N-grams in addition"""

    masking_use_longer_ngrams: bool = None
    """Whether to favor longer N-grams over shorter N-grams"""

    masking_use_geometric_distribution: bool = None
    """Whether to draw the size of the N-gram from a geometric distribution according to SpanBERT
    https://arxiv.org/abs/1907.10529 (Section 3.1)
    """

    def __post_init__(self) -> None:
        """Do asserts and set fields post init"""
        super().__post_init__()

        assert self.tokenizer is not None

        assert self.masking_probability is not None
        assert self.short_sequence_probability is not None
        assert self.masking_max_ngram is not None
        assert self.masking_do_full_word is not None
        assert self.masking_do_permutation is not None
        assert self.masking_use_longer_ngrams is not None
        assert self.masking_use_geometric_distribution is not None

        assert self.masking_probability > 0 and self.masking_probability < 1.0
        assert self.short_sequence_probability >= 0 and self.short_sequence_probability <= 1.0
        assert self.masking_max_ngram > 0
        assert not (self.masking_use_geometric_distribution and self.masking_do_permutation)

        if self.masking_use_geometric_distribution and self.masking_use_longer_ngrams:
            log_single_rank(
                logger,
                logging.WARNING,
                "The use of a geometric distribution overrides the default distribution",
            )

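# Illustrative sketch, not part of this module: one plausible way to populate the config for a
# BERT-style masked WordPiece dataset. The tokenizer object and the numeric values below are
# assumptions chosen for illustration, and the fields inherited from BlendedMegatronDatasetConfig
# (e.g. random_seed, sequence_length, the blend/split arguments) are elided.
#
#     config = MaskedWordPieceDatasetConfig(
#         ...,                               # inherited BlendedMegatronDatasetConfig fields
#         tokenizer=tokenizer,               # a WordPiece tokenizer exposing cls/sep and inv_vocab
#         masking_probability=0.15,          # target roughly 15% of the tokens for masking
#         short_sequence_probability=0.1,    # 10% of samples are built shorter than the target
#         masking_max_ngram=3,               # consider 1-, 2-, and 3-gram spans
#         masking_do_full_word=True,         # whole-word masking
#         masking_do_permutation=False,
#         masking_use_longer_ngrams=False,
#         masking_use_geometric_distribution=False,
#     )
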
class MaskedWordPieceDataset(MegatronDataset):
    """The semi-abstract base class for masked WordPiece datasets

    This implementation makes the rigid assumption that all inheritor datasets are built upon the
    IndexedDataset class. This assumption may be pushed down to the inheritors in future if
    necessary.

    NB: WordPiece tokenization prepends a double hash "##" to all tokens/pieces in a word, save
    the first token/piece.

    Args:
        indexed_dataset (IndexedDataset): The IndexedDataset around which to build the
            MegatronDataset

        dataset_path (str): The real path on disk to the dataset, for bookkeeping

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When
            None, build as many samples as correspond to one epoch.

        index_split (Split): The indexed_indices Split

        config (MaskedWordPieceDatasetConfig): The config
    """

    def __init__(
        self,
        indexed_dataset: IndexedDataset,
        dataset_path: str,
        indexed_indices: numpy.ndarray,
        num_samples: Optional[int],
        index_split: Split,
        config: MaskedWordPieceDatasetConfig,
    ) -> None:
        super().__init__(
            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
        )

    @staticmethod
    def numel_low_level_dataset(low_level_dataset: IndexedDataset) -> int:
        return low_level_dataset.document_indices.shape[0] - 1

    @staticmethod
    def build_low_level_dataset(
        dataset_path: str, config: MaskedWordPieceDatasetConfig
    ) -> IndexedDataset:
        return IndexedDataset(dataset_path)

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Inherited method implementation

        Returns:
            List[str]: The key config attributes
        """
        return super(MaskedWordPieceDataset, MaskedWordPieceDataset)._key_config_attributes() + [
            "masking_probability",
            "short_sequence_probability",
            "masking_max_ngram",
            "masking_do_full_word",
            "masking_do_permutation",
            "masking_use_longer_ngrams",
            "masking_use_geometric_distribution",
        ]

    def __len__(self) -> int:
        return self.sample_index.shape[0]

    def _build_sample_index(
        self, sequence_length: int, min_sentences_per_sample: int
    ) -> numpy.ndarray:
        path_to_cache = self.config.path_to_cache
        if path_to_cache is None:
            path_to_cache = os.path.join(
                self.dataset.path_prefix, "cache", f"{type(self).__name__}_indices"
            )

        get_path_to = lambda suffix: os.path.join(
            path_to_cache, f"{self.unique_description_hash}-{type(self).__name__}-{suffix}"
        )
        path_to_description = get_path_to("description.txt")
        path_to_sample_index = get_path_to("sample_index.npy")
        cache_hit = all(map(os.path.isfile, [path_to_description, path_to_sample_index]))

        if self.num_samples is not None:
            num_epochs = numpy.iinfo(numpy.int32).max - 1
        else:
            num_epochs = 1

        if not cache_hit and torch.distributed.get_rank() == 0:
            log_single_rank(
                logger,
                logging.INFO,
                f"Build and save the {type(self).__name__} {self.index_split.name} indices",
            )
            self.built_anew_on_cache_miss = True

            os.makedirs(path_to_cache, exist_ok=True)

            # Write the description
            with open(path_to_description, "wt") as writer:
                writer.write(self.unique_description)

            # Build the sample index
            log_single_rank(
                logger,
                logging.INFO,
                f"\tBuild and save the sample index to {os.path.basename(path_to_sample_index)}",
            )
            t_beg = time.time()
            from megatron.core.datasets import helpers

            # Add +1 for access to document upper bound
            indices = numpy.append(self.indices, self.indices[-1] + 1)

            sample_index = helpers.build_mapping(
                self.dataset.document_indices[indices],
                self.dataset.sequence_lengths,
                num_epochs,
                self.num_samples,
                sequence_length,
                self.config.short_sequence_probability,
                self.config.random_seed,
                False,
                min_sentences_per_sample,
            )
            numpy.save(path_to_sample_index, sample_index, allow_pickle=True)
            t_end = time.time()
            log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

            log_single_rank(
                logger, logging.INFO, f"> total number of samples: {sample_index.shape[0]}"
            )
            log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

            return sample_index

        log_single_rank(
            logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"
        )

        log_single_rank(
            logger,
            logging.INFO,
            f"\tLoad the sample index from {os.path.basename(path_to_sample_index)}",
        )
        t_beg = time.time()
        sample_index = numpy.load(path_to_sample_index, allow_pickle=True, mmap_mode="r")
        t_end = time.time()
        log_single_rank(logger, logging.DEBUG, f"\t> time elapsed: {t_end - t_beg:4f} seconds")

        return sample_index

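    # NB: the cache consists of two files keyed by the unique description hash:
    # "<hash>-<class name>-description.txt" holding the dataset description and
    # "<hash>-<class name>-sample_index.npy" holding the sample index itself.
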
    def _create_masked_lm_predictions(
        self,
        token_ids: List[int],
        target_sequence_length: int,
        numpy_random_state: numpy.random.RandomState,
    ) -> Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]:
        """Creates the predictions for the masked LM objective

        Args:
            token_ids (List[int]): The token ids

            target_sequence_length (int): The target sequence length

            numpy_random_state (numpy.random.RandomState): The NumPy random state

        Returns:
            Tuple[List[int], List[int], List[int], List[int], List[Tuple[List[int], List[int]]]]:
                1. masked_token_ids -> The masked sequence
                2. masked_positions -> The indices for the masked token ids
                3. masked_labels -> The original token ids for the masked token ids
                4. boundaries -> The sentence and word boundaries for the sequence
                5. masked_spans -> The masked positions and labels with N-gram info intact
        """
        # Build the token sentence and word boundaries and the masking candidates
        # e.g. [cls, id, ##id, ##id, id, ##id, sep, id, ##id, sep]
        # -> boundaries: [1, 1, 0, 0, 1, 0, 1, 1, 0, 1]
        # -> candidates with whole word masking: [[1, 2, 3], [4, 5], [7, 8]]
        # -> candidates sans whole word masking: [[1], [2], [3], [4], [5], [7], [8]]
        boundaries = []
        candidates = []
        for i, token_id in enumerate(token_ids):
            if token_id == self.config.tokenizer.cls or token_id == self.config.tokenizer.sep:
                boundaries.append(1)
            else:
                if not self.config.tokenizer.inv_vocab[token_id].startswith("##"):
                    boundaries.append(1)
                    candidates.append([i])
                else:
                    boundaries.append(0)
                    if self.config.masking_do_full_word and len(candidates) > 0:
                        candidates[-1].append(i)
                    else:
                        candidates.append([i])

        n_maskings = min(
            self.config.masking_probability * target_sequence_length,
            max(1, int(round(len(token_ids) * self.config.masking_probability))),
        )

        ngram_nvals = numpy.arange(self.config.masking_max_ngram, dtype=numpy.int64) + 1

        # By default, the N-gram probabilities are inversely proportional to N
        # e.g. N = 3
        # -> P = array([0.54545455, 0.27272727, 0.18181818])
        nprobs = 1.0 / ngram_nvals
        nprobs = nprobs / nprobs.sum(keepdims=True)
        if self.config.masking_use_longer_ngrams:
            nprobs = nprobs[::-1]

        # Create a nested list of depth 3
        #   layer 1: the candidate dimension
        #   layer 2: the N-gram dimension
        #   layer 3: the token dimension
        candidate_ngrams = [
            [candidates[idx : idx + n] for n in ngram_nvals] for idx in range(len(candidates))
        ]
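        # e.g. continuing the whole-word candidates above with masking_max_ngram = 2
        # -> candidate_ngrams[0] = [[[1, 2, 3]], [[1, 2, 3], [4, 5]]]
        # i.e. the 1-gram and 2-gram windows starting at the first candidate word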
        numpy_random_state.shuffle(candidate_ngrams)

        masked_token_ids = list(token_ids)
        masked_positions_and_labels = []
        masked_spans = []
        masked_indices = set()
        for candidate_idx in range(len(candidate_ngrams)):
            n_ngrams = len(candidate_ngrams[candidate_idx])

            # Stop when we hit our desired number of maskings
            if len(masked_positions_and_labels) >= n_maskings:
                break

            # Do nothing for candidates with no ngrams
            if not candidate_ngrams[candidate_idx]:
                continue

            # Choose the initial value of N
            if self.config.masking_use_geometric_distribution:
                # Sample N from a geometric distribution with p = 0.2 and clip
                # i.e. SpanBERT
                # -> https://arxiv.org/abs/1907.10529 (Section 3.1)
                p = 0.2
                n = min(numpy_random_state.geometric(p), self.config.masking_max_ngram)
            else:
                p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True)
                n = numpy_random_state.choice(ngram_nvals[:n_ngrams], p=p)

            while True:
                ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], [])
                n = n - 1
                # Success: masking this N-gram puts us below the desired number of maskings
                if n_maskings >= len(masked_positions_and_labels) + len(ngram_indices):
                    skip_candidate = False
                    break
                # Failure: no N-grams remain for this candidate
                if n == 0:
                    skip_candidate = True
                    break

            # Do nothing for candidates whose 1-gram is too long
            if skip_candidate:
                continue

            # Do nothing for candidate indices which have already been masked
            if any(map(lambda idx: idx in masked_indices, ngram_indices)):
                continue

            # Mask the tokens and record their original positions and values
            for index in ngram_indices:
                masked_indices.add(index)
                mask = self._get_token_mask(numpy_random_state)
                if mask is None:
                    masked_token_ids[index] = token_ids[index]
                else:
                    masked_token_ids[index] = mask
                masked_positions_and_labels.append((index, token_ids[index]))

            masked_spans.append((ngram_indices, [token_ids[index] for index in ngram_indices]))

        assert len(masked_positions_and_labels) <= n_maskings

        numpy_random_state.shuffle(candidate_ngrams)

        if self.config.masking_do_permutation:
            n_swappings = n_maskings

            permuted_indices = set()
            for candidate_idx in range(len(candidate_ngrams)):
                n_ngrams = len(candidate_ngrams[candidate_idx])

                if len(permuted_indices) >= n_swappings:
                    break

                # Do nothing for candidates with no ngrams
                if not candidate_ngrams[candidate_idx]:
                    continue

                p = nprobs[:n_ngrams] / nprobs[:n_ngrams].sum(keepdims=True)
                n = numpy_random_state.choice(ngram_nvals[:n_ngrams], p=p)

                while True:
                    ngram_indices = sum(candidate_ngrams[candidate_idx][n - 1], [])
                    n = n - 1
                    # Success: swapping this N-gram puts us below the desired number of swappings
                    if n_swappings >= len(permuted_indices) + len(ngram_indices):
                        skip_candidate = False
                        break
                    # Failure: no N-grams remain for this candidate
                    if n == 0:
                        skip_candidate = True
                        break

                # Do nothing for candidates whose 1-gram is too long
                if skip_candidate:
                    continue

                # Do nothing for candidate indices which have already been masked or permuted
                if any(
                    map(
                        lambda idx: idx in masked_indices or idx in permuted_indices,
                        ngram_indices,
                    )
                ):
                    continue

                for index in ngram_indices:
                    permuted_indices.add(index)

            assert len(permuted_indices) <= n_swappings

            permuted_indices = sorted(permuted_indices)
            permuted_indices_copy = list(permuted_indices)
            numpy_random_state.shuffle(permuted_indices_copy)
            masked_token_ids_copy = list(masked_token_ids)

            for idx, idx_copy in zip(permuted_indices, permuted_indices_copy):
                masked_token_ids[idx] = masked_token_ids_copy[idx_copy]
                masked_positions_and_labels.append((idx, masked_token_ids_copy[idx]))

        masked_positions_and_labels = sorted(masked_positions_and_labels, key=lambda x: x[0])
        masked_positions = []
        masked_labels = []
        for position, label in masked_positions_and_labels:
            masked_positions.append(position)
            masked_labels.append(label)

        masked_spans = sorted(masked_spans, key=lambda x: x[0][0])

        return masked_token_ids, masked_positions, masked_labels, boundaries, masked_spans

    @abstractmethod
    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> Optional[int]:
        pass
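

# A minimal sketch, not part of the Megatron Core API: one policy an inheritor could adopt for
# _get_token_mask, namely a BERT-style 80/10/10 scheme. The helper name and its mask_id and
# vocab_size parameters are hypothetical stand-ins for the tokenizer attributes a concrete
# dataset would actually consult.
def _example_token_mask_policy(
    numpy_random_state: numpy.random.RandomState, mask_id: int, vocab_size: int
) -> Optional[int]:
    """Return the id to write at a masked position, or None to keep the original token"""
    roll = numpy_random_state.random_sample()
    if roll < 0.8:
        # 80% of the time, substitute the mask token
        return mask_id
    if roll < 0.9:
        # 10% of the time, substitute a random vocabulary id
        return int(numpy_random_state.randint(0, vocab_size))
    # 10% of the time, return None so _create_masked_lm_predictions keeps the original token id
    return None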