import abc
import dataclasses
from typing import List, Optional, Set, Type, Union

import torch


@dataclasses.dataclass
class _ReqLike:
    origin_input_ids: List[int]


@dataclasses.dataclass
class _BatchLike:
    reqs: List[_ReqLike]

    def batch_size(self):
        return len(self.reqs)


class BatchedPenalizerOrchestrator:
    def __init__(
        self,
        vocab_size: int,
        batch: _BatchLike,
        device: str,
        Penalizers: Set[Type["_BatchedPenalizer"]],
    ):
        self.vocab_size = vocab_size
        self.batch = batch
        self.device = device
        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}

        is_required = False
        for penalizer in self.penalizers.values():
            pen_is_required = penalizer.prepare_if_required()
            is_required |= pen_is_required
        self.is_required = is_required

        input_ids = [
            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
            for req in self.reqs()
        ]
        if self.is_required:
            self.cumulate_input_tokens(input_ids=input_ids)

    def reqs(self):
        return self.batch.reqs

    def batch_size(self):
        return self.batch.batch_size()

    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
        """
        Feed the input tokens to the penalizers.

        Args:
            input_ids (List[torch.Tensor]): The input tokens.
        """
        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)

        for penalizer in self.penalizers.values():
            penalizer.cumulate_input_tokens(input_ids=token_ids)

    def cumulate_output_tokens(self, output_ids: torch.Tensor):
        """
        Feed the output tokens to the penalizers.

        Args:
            output_ids (torch.Tensor): The output tokens.
        """
        if not self.is_required:
            return

        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)

        for penalizer in self.penalizers.values():
            penalizer.cumulate_output_tokens(output_ids=token_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizers to the logits.
        Note that it may apply the penalizers in-place.

        Args:
            logits (torch.Tensor): The logits to apply the penalizers to.

        Returns:
            torch.Tensor: The logits after applying the penalizers.
        """
        if not self.is_required:
            return logits

        for penalizer in self.penalizers.values():
            logits = penalizer.apply(logits)

        return logits

    def filter(
        self,
        indices_to_keep: List[int],
        indices_tensor_to_keep: Optional[torch.Tensor] = None,
    ):
        """
        Filter the penalizers based on the indices to keep in the batch.

        Args:
            indices_to_keep (List[int]): List of indices to keep in the batch.
            indices_tensor_to_keep (Optional[torch.Tensor]): Tensor of indices to keep
                in the batch. If not None, it is used instead of converting
                indices_to_keep to a tensor.
        """
        if not self.is_required:
            return

        empty_indices = len(indices_to_keep) == 0

        is_required = False
        for penalizer in self.penalizers.values():
            tmp_is_required = penalizer.is_required()
            is_required = is_required or tmp_is_required
            if not tmp_is_required or empty_indices:
                penalizer.teardown()
            else:
                # create tensor index only when it's needed
                if indices_tensor_to_keep is None:
                    indices_tensor_to_keep = torch.tensor(
                        indices_to_keep, dtype=torch.int32, device=self.device
                    )

                penalizer.filter(
                    indices_to_keep=indices_to_keep,
                    indices_tensor_to_keep=indices_tensor_to_keep,
                )
        self.is_required = is_required
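
    # Example (illustrative only): if requests 0 and 2 of a four-request batch
    # finish, the survivors are kept with
    #
    #     orchestrator.filter(indices_to_keep=[1, 3])
    #
    # Penalizers that are no longer required are torn down; the rest slice
    # their per-request tensors down to the surviving rows.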

    def merge(self, their: "BatchedPenalizerOrchestrator"):
        """
        Merge the penalizers of another orchestrator into this one.

        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).
        Each unprepared penalizer must first be prepared (creating tensors, etc.) before merging.
        This step requires the original batch.reqs, before it is merged with the other batch.reqs.

        Args:
            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.
        """
        if not self.is_required and not their.is_required:
            return

        self.is_required |= their.is_required
        for Penalizer, their_penalizer in their.penalizers.items():
            if Penalizer not in self.penalizers:
                raise ValueError(f"Penalizer {Penalizer} not found in self.penalizers")

            self.penalizers[Penalizer].merge(their_penalizer)
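
    # --- Illustrative usage (comment-only sketch) ----------------------------
    # Hypothetical driving loop; `MyPenalizer`, `model_forward`, and
    # `decode_steps` are placeholders, not part of this module:
    #
    #     orchestrator = BatchedPenalizerOrchestrator(
    #         vocab_size=32000,
    #         batch=_BatchLike(reqs=[_ReqLike(origin_input_ids=[1, 2, 3])]),
    #         device="cuda",
    #         Penalizers={MyPenalizer},
    #     )
    #     for _ in range(decode_steps):
    #         logits = orchestrator.apply(model_forward(...))
    #         next_ids = logits.argmax(dim=-1)
    #         orchestrator.cumulate_output_tokens(next_ids)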


class _TokenIDs:
    """
    A class that wraps token IDs to provide additional utility functions to penalizers.

    Attributes:
        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that these token IDs belong to.
        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
        cached_counts (torch.Tensor): The cached occurrence count tensor.
    """

    def __init__(
        self,
        orchestrator: BatchedPenalizerOrchestrator,
        token_ids: Union[torch.Tensor, List[torch.Tensor]],
    ):
        self.orchestrator = orchestrator
        self.token_ids = token_ids
        self.cached_counts = None

    def occurrence_count(self) -> torch.Tensor:
        """
        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in the batch.

        Returns:
            torch.Tensor: The occurrence count tensor.
        """
        if self.cached_counts is not None:
            return self.cached_counts

        token_ids = self.token_ids

        if isinstance(token_ids, list):
            # TODO: optimize this part
            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
                sequences=token_ids,
                batch_first=True,
                padding_value=self.orchestrator.vocab_size,
            )
            self.cached_counts = torch.zeros(
                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
                dtype=torch.int64,
                device=self.orchestrator.device,
            ).scatter_add_(
                dim=1,
                index=padded_token_ids,
                src=torch.ones_like(padded_token_ids),
            )[
                :, : self.orchestrator.vocab_size
            ]
        else:
            # TODO: optimize this part. We do not need to create this big tensor every time.
            # We can directly apply the results on the logits.
            self.cached_counts = torch.zeros(
                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
                device=self.orchestrator.device,
            )
            self.cached_counts[
                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
            ] = 1

        return self.cached_counts
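
    # Example (illustrative only), assuming an orchestrator with vocab_size=5,
    # two requests, and device="cpu":
    #
    #     counts = _TokenIDs(
    #         orchestrator=orchestrator,
    #         token_ids=[torch.tensor([0, 1, 1]), torch.tensor([2])],
    #     ).occurrence_count()
    #     # counts == [[1, 2, 0, 0, 0],
    #     #            [0, 0, 1, 0, 0]]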


class _BatchedPenalizer(abc.ABC):
    """
    An abstract class for a batched penalizer.
    """

    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
        self.orchestrator = orchestrator
        self._is_prepared = False

    def is_prepared(self) -> bool:
        return self._is_prepared

    def is_required(self) -> bool:
        return self._is_required()

    def prepare(self):
        if not self.is_prepared():
            self._prepare()
            self._is_prepared = True

    def prepare_if_required(self):
        if self.is_required():
            self.prepare()
            return True
        else:
            return False

    def teardown(self):
        if self.is_prepared():
            self._teardown()
            self._is_prepared = False

    def cumulate_input_tokens(self, input_ids: _TokenIDs):
        if not self.is_prepared():
            return

        self._cumulate_input_tokens(input_ids=input_ids)

    def cumulate_output_tokens(self, output_ids: _TokenIDs):
        if not self.is_prepared():
            return

        self._cumulate_output_tokens(output_ids=output_ids)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.is_prepared():
            return logits

        return self._apply(logits=logits)

    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        if not self.is_prepared():
            return

        self._filter(
            indices_to_keep=indices_to_keep,
            indices_tensor_to_keep=indices_tensor_to_keep,
        )

    def merge(self, their: "_BatchedPenalizer"):
        if not self.is_prepared() and not their.is_prepared():
            return

        self.prepare()
        their.prepare()
        self._merge(their)

    @abc.abstractmethod
    def _is_required(self) -> bool:
        """
        Check if the penalizer is required to be prepared.
        """
        pass

    @abc.abstractmethod
    def _prepare(self):
        """
        Prepare the penalizer.
        Usually, this is where the penalizer initializes its tensors.
        """
        pass

    @abc.abstractmethod
    def _teardown(self):
        """
        Tear down the penalizer.
        Usually, this is where the penalizer frees its tensors.
        """
        pass

    @abc.abstractmethod
    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        """
        Cumulate the input tokens.
        Orchestrator will call this function to feed the input tokens to the penalizer.
        """
        pass

    @abc.abstractmethod
    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        """
        Cumulate the output tokens.
        Orchestrator will call this function to feed the output tokens to the penalizer.
        """
        pass

    @abc.abstractmethod
    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply the penalizer to the logits.
        Penalizers can modify the logits in-place if needed.
        """
        pass

    @abc.abstractmethod
    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        """
        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
        """
        pass

    @abc.abstractmethod
    def _merge(self, their: "_BatchedPenalizer"):
        """
        Merge the penalizer with another penalizer.
        """
        pass
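

# --- Illustrative example (not part of the real penalizer set) ---------------
# A minimal sketch of a concrete penalizer built on the hooks above. It assumes
# a hypothetical per-request `presence_penalty` attribute that does not exist
# on _ReqLike; it only shows how the abstract methods fit together, not an
# actual implementation used by this module.
class _ExamplePresencePenalizer(_BatchedPenalizer):
    def _is_required(self) -> bool:
        # Prepare only if any request sets a non-zero (hypothetical) penalty.
        return any(
            getattr(req, "presence_penalty", 0.0) != 0.0
            for req in self.orchestrator.reqs()
        )

    def _prepare(self):
        o = self.orchestrator
        # One penalty value per request, broadcast over the vocab dimension later.
        self.penalties = torch.tensor(
            [getattr(req, "presence_penalty", 0.0) for req in o.reqs()],
            dtype=torch.float32,
            device=o.device,
        ).unsqueeze(1)
        # Tracks which token IDs have already been generated by each request.
        self.seen = torch.zeros(
            (o.batch_size(), o.vocab_size), dtype=torch.bool, device=o.device
        )

    def _teardown(self):
        del self.penalties
        del self.seen

    def _cumulate_input_tokens(self, input_ids: _TokenIDs):
        # A presence-style penalty typically ignores prompt tokens.
        pass

    def _cumulate_output_tokens(self, output_ids: _TokenIDs):
        self.seen |= output_ids.occurrence_count() > 0

    def _apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Subtract the per-request penalty from every token seen so far.
        logits = logits - self.penalties * self.seen
        return logits

    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
        self.penalties = self.penalties.index_select(0, indices_tensor_to_keep)
        self.seen = self.seen.index_select(0, indices_tensor_to_keep)

    def _merge(self, their: "_ExamplePresencePenalizer"):
        self.penalties = torch.cat([self.penalties, their.penalties], dim=0)
        self.seen = torch.cat([self.seen, their.seen], dim=0)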