epochize_dataset.py 4.28 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Dict, Generic, Iterator, Optional, TypeVar

from megatron.energon.flavors.base_dataset import SavableDataset
from megatron.energon.worker import WorkerConfig
from megatron.energon.wrappers.base import BaseWrapperDataset

T_sample = TypeVar("T_sample")


class EpochizeDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]):
    """
    Uses the base dataset, and creates one epoch, which has `length` samples. Keeps the underlying
    dataset iterator alive over epochs (i.e. if it is an infinite dataset, it will keep the state).
    Repeats the underlying dataset if the iterator is exhausted.
    """

    # Total number of samples per epoch, summed over all workers.
    length: int
    # Live iterator over the inner dataset; kept across epochs so state carries over.
    _active_iter: Optional[Iterator[T_sample]]
    # Number of samples already emitted in the current local epoch (savable/restorable).
    _offset: int

    _savable_fields = ("_offset",)

    def __init__(
        self,
        dataset: SavableDataset[T_sample],
        length: int,
        worker_config: WorkerConfig,
    ):
        """
        Create the epochized dataset.

        Args:
            dataset: The source dataset (possibly infinite)
            length: Number of samples to iterate before iteration stops (i.e. one epoch). When
                iteration continues, the original dataset iterator is resumed and does only restart
                if exhausted.
            worker_config: Configuration for the workers.
        """
        super().__init__(dataset, worker_config=worker_config)
        self.length = length
        self._active_iter = None

        self.reset_state_own()

    def reset_state_own(self) -> None:
        """Reset this wrapper's own savable state (the epoch offset)."""
        self._offset = 0

    def __iter__(self) -> Iterator[T_sample]:
        # Compute the local length for this worker, i.e. all worker's lengths sum up to the total.
        # The remainder `length % num_workers` is spread over the first workers, one extra each.
        if self.worker_config.num_workers <= 1:
            local_length = self.length
        else:
            local_length = self.length // self.worker_config.num_workers
            if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers:
                local_length += 1

        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "EpochizeDataset.epoch_start",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "offset": self._offset,
                    "local_length": local_length,
                    "length": self.length,
                }
            )

        # Start from the restored offset so a resumed epoch continues where it left off.
        offset_range = list(range(self._offset, local_length))

        # Only iterate if there are samples to iterate
        if len(offset_range) > 0:
            if self._active_iter is None:
                self._active_iter = iter(self.dataset)

            for idx in offset_range:
                # Advance the savable offset *before* yielding, so a checkpoint taken while the
                # consumer processes the sample resumes after it. Wraps to 0 at epoch end.
                self._offset = (idx + 1) % local_length
                try:
                    sample = next(self._active_iter)
                except StopIteration:
                    # Inner dataset exhausted: restart it, as documented ("Repeats the underlying
                    # dataset if the iterator is exhausted"). Previously the exhausted iterator
                    # was kept, making every subsequent epoch empty.
                    self._active_iter = iter(self.dataset)
                    try:
                        sample = next(self._active_iter)
                    except StopIteration:
                        # The dataset is empty; stop to avoid an infinite restart loop.
                        break
                yield sample

        if self.worker_config.should_log(level=2):
            self.worker_config.worker_log(
                {
                    "t": "EpochizeDataset.epoch_end",
                    "r": self.worker_config.rank,
                    "w": self.worker_config.rank_worker_id(),
                    "offset": self._offset,
                    "local_length": local_length,
                    "length": self.length,
                }
            )

    def len_worker(self, worker_idx: Optional[int] = None) -> int:
        """
        Return the epoch length for one worker.

        Args:
            worker_idx: The worker to query; defaults to the current worker (must then be called
                from within a worker).

        Returns:
            The number of samples this worker yields per epoch; all workers' lengths sum to
            ``self.length``.
        """
        if worker_idx is None:
            self.worker_config.assert_worker()
            worker_idx = self.worker_config.rank_worker_id()
        if self.worker_config.num_workers <= 1:
            assert worker_idx == 0
            return self.length
        else:
            # Same remainder distribution as in __iter__: first workers get one extra sample.
            local_length = self.length // self.worker_config.num_workers
            if worker_idx < self.length % self.worker_config.num_workers:
                local_length += 1
            return local_length

    def config(self) -> Dict[str, Any]:
        """Return a config dict describing this wrapper and its inner dataset."""
        return {
            "type": type(self).__qualname__,
            "dataset": self.dataset.config(),
            "length": self.length,
            "worker_config": self.worker_config.config(),
        }

    def __str__(self):
        return f"EpochizeDataset(length={self.length}, dataset={self.dataset})"