# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

"""Utilities for Retro preprocessing."""

import glob
import logging
import os
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from tqdm import tqdm

from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
    MultiSplitGPTDataset,
    MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank

from .external_libs import h5py

logger = logging.getLogger(__name__)


def log_retro_rank_0(message: str) -> None:
    """Log on rank 0.

    Args:
        message (str): Message to log.
    """
    log_single_rank(logger, logging.INFO, "[RETRO] " + message)


def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
    """Make a directory, conditional on not being in validation mode.

    Args:
        config (RetroPreprocessingConfig): Retro preprocessing config.
        path (str): Path to directory.
    """
    if config.retro_task_validate is None:
        os.makedirs(path, exist_ok=True)


def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
    """Extract data config from dataset.

    Args:
        config (RetroPreprocessingConfig): Retro preprocessing config.

    Returns:
        The config object used to build the dataset.
    """
    return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config


def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
    """Compute seq_length // chunk_length.

    Args:
        sample_length (int): Alias of `sequence_length`.
        chunk_length (int): Retro chunk length (e.g., 64).

    Returns:
        Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
    """
    assert sample_length % chunk_length == 0, "sample_length must be divisible by chunk_length"
    return sample_length // chunk_length
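
# For example, with a GPT sequence_length of 2048 and the Retro chunk_length of 64,
# get_num_chunks_per_sample(2048, 64) returns 2048 // 64 = 32 chunks per sample.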


class GPTToTextDataset(torch.utils.data.Dataset):
    """Dataset to convert GPT tokens to text.

    Args:
        gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
        gpt_tokenizer (Any): GPT tokenizer.
    """

    def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):

        super().__init__()

        self.gpt_dataset = gpt_dataset
        self.gpt_tokenizer = gpt_tokenizer

    def __len__(self) -> int:
        """Dataset length.

        Returns:
            Number of samples in the dataset.
        """
        return len(self.gpt_dataset)

    def __getitem__(self, idx: int) -> dict:
        """Get dataset sample.

        Args:
            idx (int): Index of sample.

        Returns:
            A dict containing attribute 'text' of type string.
        """
        gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
        text = self.gpt_tokenizer.detokenize(gpt_token_ids)
        return {"text": text}


def get_blocks(
    dirname: str, n_samples: int, block_size: int, validate: Optional[Callable] = None
) -> SimpleNamespace:
    """Divide the range [0, n_samples) into a sequence of block ranges.

    This is a core method of block processing. The idea is to divide a range
    of size n_samples into a sequence of blocks, where each block corresponds
    to a file within 'dirname' named '{start_idx}-{end_idx}.hdf5'. This method
    checks for the existence of these files, and returns two lists: one for
    existing blocks and one for missing blocks.

    Args:
        dirname (str): Path to directory containing block files.
        n_samples (int): Total number of samples to cover. The number of samples contained in saved block files is <= n_samples.
        block_size (int): Max number of samples per block file (e.g., 100000).
        validate (Optional[Callable]): Method for validating each block file during load.

    Returns:
        A namespace consisting of 2 lists: existing blocks and missing blocks. Together, the existing and missing blocks cover all n_samples samples.
    """

    assert os.path.isdir(dirname), "missing directory '%s'." % dirname

    # Block ranges.
    block_start_idxs = list(range(0, n_samples, block_size))
    block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
    block_ranges = list(zip(block_start_idxs, block_end_idxs))

    # All block files (existing + missing).
    n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
    all_blocks = [
        {
            "range": r,
            "path": os.path.join(
                dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r])
            ),
        }
        for r in block_ranges
    ]
    all_block_path_set = set(block["path"] for block in all_blocks)

    # Validate function.
    validate = (lambda f: None) if validate is None else validate

    # Delete corrupt files.
    if torch.distributed.get_rank() == 0:
        existing_block_paths = [
            block["path"] for block in all_blocks if os.path.exists(block["path"])
        ]
        for path in tqdm(existing_block_paths, "validating blocks"):

            assert path in all_block_path_set, "unexpected filename, '%s'." % path

            try:
                f = h5py.File(path, "r")
            except Exception:
                os.remove(path)
                continue

            try:
                validate(f)
            except Exception:
                os.remove(path)
            finally:
                f.close()

    # Wait for files to be deleted.
    torch.distributed.barrier()

    # Collect blocks.
    blocks = SimpleNamespace(
        existing=[b for b in all_blocks if os.path.exists(b["path"])],
        missing=[b for b in all_blocks if not os.path.exists(b["path"])],
    )

    return blocks
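
# Illustrative sketch (not part of the module API): with n_samples=1000 and
# block_size=400, get_blocks() expects up to three files in `dirname`, named
# '0000-0400.hdf5', '0400-0800.hdf5', and '0800-1000.hdf5' (indices are zero-padded
# to a width derived from n_samples). Assuming torch.distributed is initialized and
# each HDF5 file stores a dataset named "data" (a hypothetical layout used only for
# this example):
#
#     blocks = get_blocks(dirname, n_samples=1000, block_size=400,
#                         validate=lambda f: f["data"].shape)
#
# `blocks.existing` and `blocks.missing` together cover the three block ranges; each
# entry is a dict with keys "range" and "path".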


def get_blocks_by_rank(
    dirname: str,
    n_samples: int,
    block_size: int,
    validate: Optional[Callable] = None,
    sample: Optional[float] = None,
) -> SimpleNamespace:
    """Divide existing and missing blocks evenly across all ranks.

    See 'get_blocks()' above for description. The returned lists of existing and
    missing blocks are split evenly across ranks via interleaving. This way,
    each rank has a roughly equal number of blocks to process for a
    downstream operation.

    Args:
        dirname (str): Path to directory containing block files.
        n_samples (int): Total number of samples to cover. The number of samples contained in saved block files is <= n_samples.
        block_size (int): Max number of samples per block file (e.g., 100000).
        validate (Optional[Callable]): Method for validating each block file during load.
        sample (Optional[float]): If provided, sample a random subset of the blocks. Used for validating preprocessing correctness.

    Returns:
        A namespace consisting of 2 lists: this rank's existing blocks and missing blocks. Each list is potentially a sub-sample of the total set of existing and missing blocks, depending on whether sampling is used. Additionally, the attributes n_existing_world and n_missing_world are the total number of existing and missing blocks across all ranks, independent of sampling. Together, the world's existing and missing blocks cover all n_samples samples (the final block may contain fewer than block_size samples).
    """

    # Get world blocks.
    blocks = get_blocks(dirname, n_samples, block_size, validate)

    # This rank's existing and missing files.
    data_parallel_rank = parallel_state.get_data_parallel_rank()
    data_parallel_world_size = parallel_state.get_data_parallel_world_size()
    rank_existing_blocks = blocks.existing[
        data_parallel_rank : len(blocks.existing) : data_parallel_world_size
    ]
    rank_missing_blocks = blocks.missing[
        data_parallel_rank : len(blocks.missing) : data_parallel_world_size
    ]

    # Extend rank's existing and missing blocks (with None) such that all ranks
    # have equal length lists. This allows for easier tracking of global progress.
    def get_world_max(n: int) -> int:
        """Get max value across ranks.

        Args:
            n (int): Value on this rank.

        Returns:
            Max value across all ranks.
        """
        n_tensor = torch.tensor([n], dtype=torch.long, device="cuda")
        torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
        return n_tensor.item()

    max_n_existing = get_world_max(len(rank_existing_blocks))
    max_n_missing = get_world_max(len(rank_missing_blocks))

    rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
    rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))

    # Collect blocks.
    blocks = SimpleNamespace(
        n_existing_world=len(blocks.existing),
        n_missing_world=len(blocks.missing),
        existing=rank_existing_blocks,
        missing=rank_missing_blocks,
    )

    if sample is not None:
        # Sample existing and missing blocks evenly across all ranks. The
        # returned lists of blocks are randomly sampled (without replacement)
        # to yield approximately `sample * len(blocks)` blocks (rounded up).

        # Randomly sample blocks.
        def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
            """Sample a random subset of all blocks.

            Args:
                _blocks (List[Optional[Dict]]): List of all blocks.

            Returns:
                A random subset of the blocks.
            """
            n_blocks_sample = int(np.ceil(sample * len(_blocks)))
            sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]

            np.random.seed(None)
            np.random.shuffle(sampled_blocks)

            sampled_blocks = sampled_blocks[:n_blocks_sample]
            sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))

            return sampled_blocks

        blocks.existing = sample_blocks(blocks.existing)
        blocks.missing = sample_blocks(blocks.missing)

    return blocks
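
# Illustrative sketch (not part of the module API): each rank takes an interleaved
# slice of the world-level block lists. For example, with 8 missing blocks and a
# data-parallel world size of 4, rank 0 gets blocks [0, 4], rank 1 gets [1, 5], etc.;
# shorter per-rank lists are padded with None so every rank iterates the same number
# of steps. Assuming torch.distributed and Megatron's parallel_state are initialized,
# a typical consumer loop looks like:
#
#     blocks = get_blocks_by_rank(dirname, n_samples, block_size)
#     for block in blocks.missing:
#         if block is not None:
#             process_block(block)       # hypothetical per-block worker
#         torch.distributed.barrier()    # keep ranks in lockstep across blocks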


class BlockPathMap:
    """Map an index to its containing block path.

    The common use for this class is to have a directory of files containing
    blocks of processed data, of uniform block size (e.g., 100k samples per
    file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
    where 'endIdx' minus 'startIdx' must equal the block size, with the possible
    exception of the final block. Given an input index, this class maps the
    index to the containing block file.

    Args:
        block_paths (List[str]): List of paths to saved block files.
        block_size (int): Max number of samples per block file (e.g., 100000).
    """

    @classmethod
    def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
        """Get list of block files, and create map.

        Args:
            dir (str): Path to directory containing saved block files.
            block_size (int): Max number of samples per block file (e.g., 100000).
            ext (str): Block file extension (e.g., 'hdf5').

        Returns:
            A mapping of sample index to block file path.
        """
        assert os.path.isdir(dir), f"directory not found, '{dir}'."
        return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)

    def __init__(self, block_paths: List[str], block_size: int):
        self.max_idx = 0
        self.block_path_map = {}
        for block_path in block_paths:
            name = os.path.splitext(os.path.basename(block_path))[0]
            start_idx, end_idx = [int(i) for i in name.split("-")]
            self.block_path_map[start_idx] = block_path
            self.max_idx = max(self.max_idx, end_idx)
        self.block_size = block_size

    def __str__(self) -> str:
        """Stringify the mapping.

        Returns:
            A string representation of this block path map.
        """
        return "%d paths" % len(self.block_path_map)

    def __getitem__(self, idx: int) -> str:
        """Get block path from index.

        Args:
            idx (int): Index of sample.

        Returns:
            The path to the block file containing the sample index.
        """
        block_start_idx = self.block_size * (idx // self.block_size)
        block_path = self.block_path_map[block_start_idx]
        return block_path
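

# Illustrative usage sketch (not part of the module API): BlockPathMap answers "which
# block file contains sample index i?" by rounding i down to its block's start index.
# Assuming a directory of block files as written above (e.g., '0000-0400.hdf5',
# '0400-0800.hdf5', '0800-1000.hdf5', with block_size=400):
#
#     block_path_map = BlockPathMap.from_dir(dirname, block_size=400)
#     path = block_path_map[999]  # -> '.../0800-1000.hdf5' (block starting at 800)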