# orchestrator.py
from __future__ import annotations

import abc
import weakref
from typing import TYPE_CHECKING, Optional, Set, Type

import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


class BatchedPenalizerOrchestrator:
    def __init__(
        self,
        vocab_size: int,
        batch: ScheduleBatch,
        penalizers: Set[Type["_BatchedPenalizer"]],
    ):
        self.vocab_size = vocab_size
        self._batch_ref = weakref.ref(batch)
        self.device = batch.device
        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}

        is_required = False
        for penalizer in self.penalizers.values():
            pen_is_required = penalizer.prepare_if_required()
            is_required |= pen_is_required
        self.is_required = is_required

    @property
    def batch(self) -> ScheduleBatch | None:
        return self._batch_ref()

    @batch.setter
    def batch(self, value: Optional[ScheduleBatch]):
        if value is None:
            self._batch_ref = lambda: None
        else:
            self._batch_ref = weakref.ref(value)

    def reqs(self):
        return self.batch.reqs

    def cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Feed the output tokens to the penalizers.

        Args:
            output_ids (torch.Tensor): The output tokens.
        """
        for penalizer in self.penalizers.values():
            penalizer.cumulate_output_tokens(output_ids=output_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizers to the logits.
        Note that it may apply the penalizers in-place.

        Args:
            logits (torch.Tensor): The logits to apply the penalizers to.

        Returns:
            torch.Tensor: The logits after applying the penalizers.
        """
        for penalizer in self.penalizers.values():
            logits = penalizer.apply(logits)
        return logits

    def filter(self, keep_indices: torch.Tensor):
        """
        Filter the penalizers based on the indices to keep in the batch.

        Args:
            keep_indices (torch.Tensor): Tensor of indices to keep in the batch.
        """
        if not self.is_required:
            return

        if len(keep_indices) == 0:
            # No requests left in the batch, fully release orchestrator resources
            self.release()
            return

        is_required = False
        for penalizer in self.penalizers.values():
            tmp_is_required = penalizer.is_required()
            is_required |= tmp_is_required
            if tmp_is_required:
                penalizer.filter(keep_indices=keep_indices)
            else:
                penalizer.teardown()
        self.is_required = is_required
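
    # Illustrative example (hypothetical values): keeping only requests 0 and 2
    # of a four-request batch would be
    #
    #     orchestrator.filter(torch.tensor([0, 2], device=orchestrator.device))
    #
    # after which each still-required penalizer's tensors are sliced to those
    # rows and the no-longer-required ones are torn down.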

    # Resource management helpers
    def release(self) -> None:
        """Release all penalizers and break references so GC can reclaim promptly."""
        for penalizer in self.penalizers.values():
            penalizer.teardown()
        self.penalizers.clear()
        # Break reference to ScheduleBatch
        self._batch_ref = lambda: None
        self.is_required = False

    # Context manager support
    def __enter__(self) -> "BatchedPenalizerOrchestrator":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.release()
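
    # Usage sketch (illustrative; constructing a real ScheduleBatch and the set
    # of concrete penalizer classes is assumed, not shown here):
    #
    #     with BatchedPenalizerOrchestrator(
    #         vocab_size=vocab_size,
    #         batch=batch,
    #         penalizers=penalizer_classes,
    #     ) as orchestrator:
    #         logits = orchestrator.apply(logits)
    #         orchestrator.cumulate_output_tokens(output_ids)
    #
    # On exit, release() tears down every penalizer and drops the batch reference.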

    def merge(self, their: "BatchedPenalizerOrchestrator"):
        """
        Merge the penalizers of another orchestrator into this one.

        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
        Each unprepared penalizer has to be prepared first (creating tensors, etc.) before merging.
        This step requires the original batch.reqs, before it is merged with the other batch's reqs.

        Args:
            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
        """
        if not self.is_required and not their.is_required:
            return

        self.is_required = True
        for penalizer, their_penalizer in their.penalizers.items():
            self.penalizers[penalizer].merge(their_penalizer)
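
    # Call-order sketch (illustrative; `ours` and `theirs` are assumed
    # orchestrators of two batches about to be combined). merge() must run
    # while batch.reqs still reflects the pre-merge batch:
    #
    #     ours.merge(theirs)                          # prepare + concat tensors
    #     ours.batch.reqs.extend(theirs.batch.reqs)   # only then combine reqs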


class _BatchedPenalizer(abc.ABC):
    """
    An abstract class for a batched penalizer.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
            weakref.ref(orchestrator)
        )
        self._is_prepared = False

    @property
    def orchestrator(self) -> BatchedPenalizerOrchestrator:
        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
        # This should never happen, but we need to handle it gracefully
        if orch is None:
            raise RuntimeError(
                "BatchedPenalizerOrchestrator has been garbage-collected"
            )
        return orch

    def is_prepared(self) -> bool:
        return self._is_prepared

    def is_required(self) -> bool:
        return self._is_required()

    def prepare(self):
        if not self._is_prepared:
            self._prepare()
            self._is_prepared = True

    def prepare_if_required(self) -> bool:
        if self._is_required():
            self.prepare()
            return True
        else:
            return False

    def teardown(self):
        self._teardown()
        self._is_prepared = False

    def cumulate_output_tokens(self, output_ids: torch.Tensor):
        if not self._is_prepared:
            return

        self._cumulate_output_tokens(output_ids=output_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self._is_prepared:
            return logits

        return self._apply(logits=logits)

    def filter(self, keep_indices: torch.Tensor):
        if not self._is_prepared:
            return

        self._filter(keep_indices=keep_indices)

    def merge(self, their: "_BatchedPenalizer"):
        if not self._is_prepared and not their._is_prepared:
            return

        self.prepare()
        their.prepare()
        self._merge(their)

    @abc.abstractmethod
    def _is_required(self) -> bool:
        """
        Check if the penalizer is required to be prepared.
        """
        pass

    @abc.abstractmethod
    def _prepare(self):
        """
        Prepare the penalizer.
        Usually, this is where the penalizer initializes its tensors.
        """
        pass

    @abc.abstractmethod
    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Cumulate the output tokens.
        The orchestrator calls this function to feed newly generated tokens to the penalizer.
        """
        pass

    @abc.abstractmethod
    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizer to the logits.
        Penalizers can modify the logits in-place if needed.
        """
        pass

    @abc.abstractmethod
    def _filter(self, keep_indices: torch.Tensor):
        """
        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
        """
        pass

    @abc.abstractmethod
    def _merge(self, their: "_BatchedPenalizer"):
        """
        Merge the penalizer with another penalizer.
        """
        pass

    @abc.abstractmethod
    def _teardown(self):
        """
        Teardown the penalizer.
        """
        pass
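

# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal concrete penalizer showing how the
# abstract hooks above fit together. The class name and the fixed presence
# penalty are assumptions for demonstration, not part of this module's API.
# ---------------------------------------------------------------------------
class _ExamplePresencePenalizer(_BatchedPenalizer):
    """Subtract a fixed penalty from the logits of every token already generated."""

    PENALTY = 0.5  # assumed constant, purely for illustration

    def _is_required(self) -> bool:
        # A real penalizer would inspect per-request sampling params here.
        return True

    def _prepare(self):
        # One boolean row per request, one column per vocab entry.
        orchestrator = self.orchestrator
        self.seen = torch.zeros(
            (len(orchestrator.reqs()), orchestrator.vocab_size),
            dtype=torch.bool,
            device=orchestrator.device,
        )

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
        # output_ids is assumed to hold one new token id per request: shape (B,).
        self.seen.scatter_(1, output_ids.unsqueeze(1), True)

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits[self.seen] -= self.PENALTY
        return logits

    def _filter(self, keep_indices: torch.Tensor):
        self.seen = self.seen[keep_indices]

    def _merge(self, their: "_BatchedPenalizer"):
        # Both sides are prepared by merge() before _merge() is called.
        self.seen = torch.cat([self.seen, their.seen], dim=0)

    def _teardown(self):
        self.seen = None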