# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

import numpy

from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
    MaskedWordPieceDataset,
    MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split


@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
    """Configuration object for Megatron Core T5 WordPiece datasets

    NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which
    defines a number of special sentinel tokens used during sampling. The assert in __post_init__
    preserves compatibility with Megatron-LM until the T5 tokenizer is moved into Megatron Core.
    """

    sequence_length_encoder: Optional[int] = field(init=False, default=None)
    """A sequence_length alias and the sequence length for the encoder"""

    sequence_length_decoder: Optional[int] = None
    """The sequence length for the decoder"""

    def __post_init__(self) -> None:
        """Do asserts and set fields post init
        """
        super().__post_init__()

        self.sequence_length_encoder = self.sequence_length

        assert self.sequence_length_encoder is not None
        assert self.sequence_length_decoder is not None

        assert len(self.tokenizer.additional_special_tokens_ids) > 0

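# A minimal construction sketch (illustrative only). Fields not defined in this file,
# e.g. random_seed and masking_probability, are assumed to come from the parent
# MaskedWordPieceDatasetConfig / BlendedMegatronDatasetConfig; check those definitions
# before relying on the exact names. The tokenizer must expose the attributes used by
# T5MaskedWordPieceDataset: additional_special_tokens_ids, bos, eos, pad, mask, inv_vocab.
#
#     config = T5MaskedWordPieceDatasetConfig(
#         random_seed=1234,
#         sequence_length=512,          # aliased to sequence_length_encoder in __post_init__
#         sequence_length_decoder=128,
#         masking_probability=0.15,     # assumed parent field name
#         tokenizer=tokenizer,
#     )
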

class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
    """The T5 dataset that assumes WordPiece tokenization

    Args:
        indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset

        dataset_path (str): The real path on disk to the dataset, for bookkeeping

        indexed_indices (numpy.ndarray): The set of the documents indices to expose

        num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.

        index_split (Split): The indexed_indices Split

        config (T5MaskedWordPieceDatasetConfig): The config
    """

    def __init__(
        self,
        indexed_dataset: IndexedDataset,
        dataset_path: str,
        indexed_indices: numpy.ndarray,
        num_samples: Optional[int],
        index_split: Split,
        config: T5MaskedWordPieceDatasetConfig,
    ) -> None:
        super().__init__(
            indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
        )

        self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
        # Account for the single <bos> and single <eos> token ids
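        # Each sample_index entry is an (idx_beg, idx_end, target_sequence_length)
        # triple mapping a sample to a contiguous range of dataset entries; see
        # __getitem__ for how the triple is consumed.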
        self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)

    @staticmethod
    def _key_config_attributes() -> List[str]:
        """Inherited method implementation

        Returns:
            List[str]: The key config attributes
        """
        return super(
            T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
        )._key_config_attributes() + ["sequence_length_decoder"]

    def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
        """Abstract method implementation
 
        Args:
            idx (int): The index into the dataset

        Returns:
            Dict[str, Union[int, numpy.ndarray]]: The sample, comprising the encoder and decoder
            token ids, attention masks, labels, and loss mask
        """
        idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
        sample = [self.dataset[i] for i in range(idx_beg, idx_end)]

        numpy_random_state = numpy.random.RandomState(
            seed=(self.config.random_seed + idx) % 2 ** 32
        )
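        # Seeding with (random_seed + idx) makes the masking for a given sample index
        # deterministic and reproducible across runs with the same config.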

        assert target_sequence_length <= self.config.sequence_length

        # Flatten the sample into a list of tokens
        tokens = [token for sentence in sample for token in sentence]

        # Truncate the list of tokens to a desired length
        truncated = len(tokens) > target_sequence_length
        tokens = tokens[:target_sequence_length]

        # Masking
        (tokens, _, _, _, masked_spans,) = self._create_masked_lm_predictions(
            tokens, target_sequence_length, numpy_random_state
        )

        # Prepare the encoder input and decoder input and output
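        # Worked example (token names are illustrative): given tokens
        # [t1, t2, t3, t4, t5] with one masked span at positions 2-3 (original
        # tokens t3, t4) and sentinel <X>, the loop below produces
        #   encoder_input  = [t1, t2, <X>, t5]
        #   decoder_input  = [<bos>, <X>, t3, t4]
        #   decoder_output = [<X>, t3, t4, <eos>]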
        sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
        encoder_input = []
        decoder_input = [self.config.tokenizer.bos]
        decoder_output = []
        idx_beg = 0
        for indices, labels in masked_spans:
            sentinel = sentinels.popleft()

            # set the end index
            idx_end = indices[0]

            encoder_input.extend(tokens[idx_beg:idx_end])
            encoder_input.append(sentinel)

            decoder_input.append(sentinel)
            decoder_input.extend(labels)

            decoder_output.append(sentinel)
            decoder_output.extend(labels)

            # set the start index
            idx_beg = indices[-1] + 1

        encoder_input.extend(tokens[idx_beg:])
        decoder_output.append(self.config.tokenizer.eos)

        # Pad the sequences and convert to NumPy
        length_toks_encoder = len(encoder_input)
        length_toks_decoder = len(decoder_input)
        length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
        length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
        assert length_pads_encoder >= 0
        assert length_pads_decoder >= 0

        encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
        encoder_input = numpy.pad(
            encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
        )

        decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
        decoder_input = numpy.pad(
            decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
        )

        # Create attention and history masks
        mask_encoder = self._make_attention_mask(encoder_input, encoder_input)
        mask_encoder_decoder = self._make_attention_mask(decoder_input, encoder_input)
        mask_decoder = self._make_attention_mask(decoder_input, decoder_input)
        mask_decoder = mask_decoder * self._make_history_mask(decoder_input)

        # Mask the labels
        decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
        decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)

        # Get the loss mask
        loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
        loss_mask[:length_toks_decoder] = 1

        return {
            "text_enc": encoder_input,
            "text_dec": decoder_input,
            "labels": decoder_output,
            "loss_mask": loss_mask,
            "truncated": int(truncated),
            "enc_mask": mask_encoder,
            "dec_mask": mask_decoder,
            "enc_dec_mask": mask_encoder_decoder,
        }

    @staticmethod
    def _make_attention_mask(
        source_block: numpy.ndarray, target_block: numpy.ndarray
    ) -> numpy.ndarray:
        """Return a 2-D attention mask

        Args:
            source_block (numpy.ndarray): A 1-D array
            target_block (numpy.ndarray): A 1-D array

        Returns:
            numpy.ndarray: The 2-D attention mask
        """
        mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
        return mask.astype(numpy.int64)
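    # Example (assuming the pad id is 0, as the ">= 1" test implies): with
    # source_block = [7, 8, 0] and target_block = [9, 0], the mask is
    #   [[1, 0],
    #    [1, 0],
    #    [0, 0]]
    # i.e. entry (i, j) is 1 only when both source token i and target token j are real.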

    @staticmethod
    def _make_history_mask(block: numpy.ndarray) -> numpy.ndarray:
        """Return a 2-D history (lower-left-triangular) mask

        Args:
            block (numpy.ndarray): A 1-D array

        Returns:
            numpy.ndarray: The 2-D history (lower-left-triangular) mask
        """
        arange = numpy.arange(block.shape[0])
        mask = arange[None, :] <= arange[:, None]
        return mask.astype(numpy.int64)
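    # Example: for a block of length 3 the mask is
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # i.e. position i may attend only to positions j <= i (causal masking).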

    def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
        """Abstract method implementation

        100% of the time, replace the token id with the mask token id.

        Args:
            numpy_random_state (RandomState): The NumPy random state

        Returns:
            int: The mask token id
        """
        return self.config.tokenizer.mask
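

# A throwaway sanity check of the two mask helpers (illustrative; importing this
# module still requires megatron.core, but the helpers themselves only need numpy):
if __name__ == "__main__":
    toks = numpy.array([5, 6, 7, 0, 0], dtype=numpy.int64)  # trailing 0s stand in for padding
    attention = T5MaskedWordPieceDataset._make_attention_mask(toks, toks)
    history = T5MaskedWordPieceDataset._make_history_mask(toks)
    print(attention)            # padding-aware, encoder-style mask
    print(attention * history)  # padding-aware and causal, as used for the decoder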