base.py 6.68 KB
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union

from megatron.energon.flavors.base_dataset import (
    FlexState,
    Sample,
    SavableDataset,
    add_sample_restore_key,
)
from megatron.energon.savable import Savable
from megatron.energon.worker import WorkerConfig

T = TypeVar("T")
T_sample = TypeVar("T_sample", covariant=True)
T_sample_out = TypeVar("T_sample_out", covariant=True)
T_sample_in = TypeVar("T_sample_in", covariant=True)


class BaseWrapperDataset(SavableDataset[T_sample_out], Generic[T_sample_in, T_sample_out], ABC):
    """Base class for dataset wrappers. All dataset wrappers should derive from this. A dataset
    wrapper takes one dataset and modifies its samples to make a new dataset. This can be for
    shuffling samples or applying custom functions to the data. Some wrappers only modify the
    length of the dataset or how it's repeated."""

    # The wrapped inner dataset(s); always stored as a tuple, even for a single dataset.
    datasets: Tuple[SavableDataset[T_sample_in], ...]

    def __init__(
        self,
        datasets: Union[SavableDataset[T_sample_in], Iterable[SavableDataset[T_sample_in]]],
        *,
        worker_config: WorkerConfig,
    ):
        """
        Args:
            datasets: A single dataset or an iterable of datasets to wrap.
            worker_config: Worker config for this wrapper; must equal the worker config of
                every wrapped dataset.
        """
        super().__init__(worker_config=worker_config)

        # Normalize to a tuple so single- and multi-dataset wrappers share one code path.
        if isinstance(datasets, SavableDataset):
            self.datasets = (datasets,)
        else:
            self.datasets = tuple(datasets)

        for d in self.datasets:
            # Check that the dataset worker configs are the same as the wrapper worker config
            assert d.worker_config == self.worker_config, (
                "Dataset and wrapper worker configs must match."
            )

    @property
    def dataset(self) -> SavableDataset:
        """Convenience property, if only one dataset is wrapped."""

        assert len(self.datasets) == 1
        return self.datasets[0]

    def can_restore_sample(self) -> bool:
        """True only if every wrapped dataset can restore samples by key."""
        return all(ds.can_restore_sample() for ds in self.datasets)

    def assert_can_restore(self) -> None:
        """Raises (via the inner datasets) if any wrapped dataset cannot restore samples."""
        for ds in self.datasets:
            ds.assert_can_restore()

    def worker_has_samples(self) -> bool:
        """True if any wrapped dataset has samples for the current worker."""
        return any(ds.worker_has_samples() for ds in self.datasets)

    def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDataset]:
        """Find the outermost dataset wrapped in this dataset that is of type cls."""

        for ds in self.datasets:
            if isinstance(ds, cls):
                return ds
            elif isinstance(ds, BaseWrapperDataset):
                res = ds._find_wrapped_dataset(cls)
                if res is not None:
                    return res
        return None

    def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out:
        """Restores a sample from its restore key.

        With a single wrapped dataset, the key is forwarded unchanged. With multiple
        datasets, the key is expected to be prefixed with
        ``(type(self).__name__, dataset_index)``, which routes the remaining key to the
        right inner dataset.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].restore_sample(restore_key)
        else:
            # `own_key` (renamed from `id`, which shadowed the builtin) identifies this
            # wrapper type in the key prefix.
            own_key, ds_idx = restore_key[:2]
            assert own_key == type(self).__name__
            restore_key = restore_key[2:]
            assert isinstance(ds_idx, int)
            return add_sample_restore_key(
                self.datasets[ds_idx].restore_sample(restore_key),
                ds_idx,
                src=self,
            )

    def save_state(self) -> FlexState:
        """Saves the own state plus the state of all wrapped datasets."""
        own_state = super().save_state()

        return FlexState(datasets=[ds.save_state() for ds in self.datasets], **own_state)

    def restore_state(self, state: FlexState) -> None:
        """Restores the wrapped datasets' states first, then the own state."""
        assert len(self.datasets) == len(state["datasets"])
        for dataset, dstate in zip(self.datasets, state["datasets"]):
            dataset.restore_state(dstate)

        super().restore_state(state)

    def reset_state_deep(self) -> None:
        """Resets the state of the inner datasets and then the own state."""

        for ds in self.datasets:
            if isinstance(ds, BaseWrapperDataset):
                ds.reset_state_deep()
            else:
                ds.reset_state_own()

        self.reset_state_own()

    @abstractmethod
    def reset_state_own(self) -> None:
        """Resets the state of the dataset, excl. the inner datasets."""
        ...


class SampleIndex(Savable):
    """A simple class to hold the sample index for one worker."""

    worker_config: WorkerConfig
    current_idx: int

    # Class-wide count of currently active `ctx()` scopes (useful for debugging nesting).
    actives = 0

    def __init__(self, worker_config: WorkerConfig, *, src: Any) -> None:
        self.worker_config = worker_config
        self.src = src
        self.current_idx = 0

    def get_next(self) -> int:
        """Returns the current sample index and advances the counter by one."""
        nxt = self.current_idx
        self.current_idx = nxt + 1
        return nxt

    @contextmanager
    def ctx(self, sample_idx: Optional[int] = None):
        """Context manager that pushes `sample_idx` (or the next index, if `None`) onto the
        active worker config's sample index stack for the duration of the scope, and verifies
        on exit that the push/pop is balanced."""
        if sample_idx is None:
            sample_idx = self.get_next()
        assert WorkerConfig.active_worker_config is not None
        WorkerConfig.active_worker_config.worker_push_sample_index(sample_idx)
        SampleIndex.actives += 1
        try:
            yield sample_idx
        finally:
            # Pop must happen even if the body raised, to keep the stack balanced.
            assert WorkerConfig.active_worker_config is not None
            popped = WorkerConfig.active_worker_config.worker_pop_sample_index()
            SampleIndex.actives -= 1
            assert popped == sample_idx, f"Expected {sample_idx}, got {popped}"

    def iter_ctx(
        self,
        it: Iterable[T_sample],
        sample_idx: Optional[int] = None,
    ) -> Generator[Tuple[int, T_sample], None, None]:
        """Iterates `it`, wrapping each `next()` call in a `ctx()` scope, and yields
        `(sample index, sample)` pairs. Closes the inner iterator on exit if it supports it."""
        inner = iter(it)
        try:
            while True:
                try:
                    # Only the `next()` call runs inside the index context; the consumer of
                    # the yielded sample runs outside it.
                    with self.ctx(sample_idx) as cur_idx:
                        sample = next(inner)
                    yield cur_idx, sample
                except StopIteration:
                    break
        finally:
            if hasattr(inner, "close"):
                inner.close()

    def save_state(self) -> int:
        """The state is just the current sample index."""
        return self.current_idx

    def restore_state(self, state: Optional[int]) -> None:
        """Restores the sample index; `None` resets it to the beginning."""
        self.current_idx = 0 if state is None else state


def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]:
    """Gets the restore key from an arbitrary sample."""
    # Objects exposing `__restore_key__` (including `Sample` instances) provide it directly.
    if hasattr(sample, "__restore_key__") or isinstance(sample, Sample):
        return sample.__restore_key__
    # Dict-based samples may carry the key as an entry; absent -> None.
    if isinstance(sample, dict):
        return sample.get("__restore_key__")
    return None