# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

##
# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import
##

import random

import numpy
import pytest
import torch

from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
from megatron.core.datasets.utils import compile_helpers
from megatron.training.tokenizer.tokenizer import _NullTokenizer
from tests.unit_tests.test_utilities import Utils

_MOCK_VOCAB_SIZE = 8192


def sample_N(dataset, N, randomize):
    """Draw the token arrays of N samples, either at random or in index order."""
    if randomize:
        indices = [random.randint(0, len(dataset) - 1) for _ in range(N)]
    else:
        indices = list(range(N))
    samples = [dataset[index]["tokens"].numpy() for index in indices]
    return samples


def test_mock_gpt_dataset():
    if torch.distributed.is_available():
        Utils.initialize_distributed()
        if torch.distributed.get_rank() == 0:
            compile_helpers()
        torch.distributed.barrier()
    else:
        compile_helpers()

    tokenizer = _NullTokenizer(vocab_size=_MOCK_VOCAB_SIZE)

    config = GPTDatasetConfig(
        random_seed=1234,
        sequence_length=1024,
        split="990,9,1",
        reset_position_ids=True,
        reset_attention_mask=True,
        eod_mask_loss=True,
        tokenizer=tokenizer,
    )

    datasets = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [100, 100, 100], lambda: True, config
    ).build()

    N = 10

    # Check iso-index variance by split
    subsets = [sample_N(dataset, N, randomize=False) for dataset in datasets]
    assert not numpy.allclose(subsets[0], subsets[1])
    assert not numpy.allclose(subsets[0], subsets[2])
    assert not numpy.allclose(subsets[1], subsets[2])

    # Check iso-split / iso-index identity
    subset_1A = sample_N(datasets[0], N, randomize=False)
    subset_1B = sample_N(datasets[0], N, randomize=False)
    assert numpy.allclose(subset_1A, subset_1B)

    # Check iso-split variance by index
    subset_1A = sample_N(datasets[0], N, randomize=True)
    subset_1B = sample_N(datasets[0], N, randomize=True)
    assert not numpy.allclose(subset_1A, subset_1B)

    config = GPTDatasetConfig(
        random_seed=1234,
        sequence_length=1024,
        split="990,10,0",
        reset_position_ids=True,
        reset_attention_mask=True,
        eod_mask_loss=True,
        drop_last_partial_validation_sequence=False,
        add_extra_token_to_sequence=False,
        tokenizer=tokenizer,
    )

    datasets = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [0, None, 0], lambda: True, config
    ).build()

    # Fetch the sample that the shuffle index maps to the last underlying
    # sample, i.e. the retained partial validation sequence
    sample = datasets[1][datasets[1].shuffle_index.argmax()]
    # Position of the final EOD token in the labels
    argmax = sample['labels'].shape[0] - torch.flip(sample['labels'], [0]).argmax() - 1

    # Test add_extra_token_to_sequence
    assert sample['tokens'][argmax] != tokenizer.eod
    assert sample['labels'][argmax] == tokenizer.eod

    # Test eod_mask_loss, drop_last_partial_validation_sequence: the loss mask
    # must be zero wherever the label is the EOD token or padding
    assert argmax < sample['labels'].shape[0] - 1
    assert torch.all(sample['labels'][argmax + 1 :] == 0)
    assert not torch.any(
        sample['loss_mask'][
            torch.logical_or(sample['labels'] == tokenizer.eod, sample['labels'] == 0)
        ]
    )

    sample = datasets[1][None]

    # Check handling of None index
    assert not torch.any(sample['loss_mask'])


if __name__ == "__main__":
    test_mock_gpt_dataset()
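

# ---------------------------------------------------------------------------
# A minimal determinism sketch, not part of the original suite: MockGPTDataset
# derives its contents from config.random_seed, so rebuilding the same splits
# from an identical GPTDatasetConfig should reproduce identical samples. The
# helper name and the [100, 100, 100] split sizes are illustrative assumptions
# mirroring the first build in test_mock_gpt_dataset; call it with a config
# constructed the same way.
# ---------------------------------------------------------------------------
def _check_rebuild_determinism_sketch(config, N=10):
    datasets_a = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [100, 100, 100], lambda: True, config
    ).build()
    datasets_b = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [100, 100, 100], lambda: True, config
    ).build()
    for dataset_a, dataset_b in zip(datasets_a, datasets_b):
        # Identical seed, split, and size should yield identical token sequences
        assert numpy.allclose(
            sample_N(dataset_a, N, randomize=False),
            sample_N(dataset_b, N, randomize=False),
        )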