helpers.py 2.08 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import numpy

# Implicit imports for backwards compatibility
# Explicit imports for readability
from megatron.core.datasets.helpers_cpp import *
from megatron.core.datasets.helpers_cpp import build_sample_idx_int32, build_sample_idx_int64


def build_sample_idx(
    sizes: numpy.ndarray,
    document_indices: numpy.ndarray,
    sequence_length: int,
    num_epochs: int,
    tokens_per_epoch: int,
    drop_last_partial_sequence: bool = True,
    add_extra_token_to_sequence: bool = True,
):
    """Build the 2-D sample index by dispatching to the int32 or int64 C++ builder.

    The narrowest builder that can hold every expected index value is chosen:
    `build_sample_idx_int32` when the bound fits in int32, otherwise
    `build_sample_idx_int64`. Both come from the compiled `helpers_cpp` module.

    Args:
        sizes (numpy.ndarray): The 1-D array of document lengths

        document_indices (numpy.ndarray): The 1-D array of document indices

        sequence_length (int): The sequence length

        num_epochs (int): The number of epochs

        tokens_per_epoch (int): The number of tokens per epoch

        drop_last_partial_sequence (bool): Whether to omit the last partial sequence in the sample
            index should it exist. Defaults to True.

        add_extra_token_to_sequence (bool): Whether to build samples with sequence length
            `sequence_length + 1`. Defaults to True. It is forwarded to the C++ builder
            as the integer 1 or 0.

    Returns:
        numpy.ndarray: The 2-D sample index

    Raises:
        ValueError: If `sizes` is empty, because numpy cannot take the max of a zero-size array.
        AssertionError: If the int32 result contains values outside `[0, bound]`.
    """
    # Bound used to decide whether int32 storage suffices: the larger of the
    # number of document indices and the longest document length.
    # NOTE(review): this presumes every sample-index entry is either a position
    # into `document_indices` or an offset within one document. Confirm against
    # helpers.cpp.
    value_bound = max(document_indices.shape[0], sizes.max())
    use_int32 = value_bound <= numpy.iinfo(numpy.int32).max

    builder = build_sample_idx_int32 if use_int32 else build_sample_idx_int64
    extra_token_flag = 1 if add_extra_token_to_sequence else 0
    sample_idx = builder(
        sizes,
        document_indices,
        sequence_length,
        num_epochs,
        tokens_per_epoch,
        drop_last_partial_sequence,
        extra_token_flag,
    )

    if use_int32:
        # Range-check only the narrow int32 result. Presumably this guards
        # against overflow/wrap-around in the 32-bit builder.
        assert sample_idx.min() >= 0 and sample_idx.max() <= value_bound

    return sample_idx