from __future__ import annotations

import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

import torch

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class SamplingBatchInfo:
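    """Batched sampling metadata for a ScheduleBatch.

    Holds the per-request sampling parameters as batched tensors, plus the
    auxiliary state used at sampling time: the penalizer orchestrator,
    grammar vocab masks, and custom logit processors.
    """
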
    # Basic batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # Whether all requests use greedy sampling
    is_all_greedy: bool

    # Whether any request needs min_p sampling
    need_min_p_sampling: bool

    # Masking tensors for grammar-guided structured outputs
    vocab_size: int
    grammars: Optional[List] = None
    vocab_mask: Optional[torch.Tensor] = None
    apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

    # An event used for overlap schedule
    sampling_info_done: Optional[threading.Event] = None

    # Penalizer
    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
    linear_penalty: Optional[torch.Tensor] = None

    # Whether any request has a custom logit processor
    has_custom_logit_processor: bool = False
    # Custom parameters
    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
    # Custom logit processor
    custom_logit_processor: Optional[
        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
    ] = None

    # Device
    device: str = "cuda"

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
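        """Build a SamplingBatchInfo from the sampling params of batch.reqs."""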
        reqs = batch.reqs
        device = batch.device
        temperatures = (
            torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float,
            )
            .view(-1, 1)
            .to(device, non_blocking=True)
        )
        top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)
        top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
        ).to(device, non_blocking=True)
        min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

        # Check if any request has a custom logit processor
        has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
        )

        if has_custom_logit_processor:
            # Merge the same type of custom logit processors together
            processor_dict = {}
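            # Group request indices by their serialized processor string so each
            # distinct processor is deserialized only once below.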
            for i, r in enumerate(reqs):
                if r.custom_logit_processor is None:
                    continue
                processor_str = r.custom_logit_processor
                if processor_str not in processor_dict:
                    processor_dict[processor_str] = []
                processor_dict[processor_str].append(i)

            merged_custom_logit_processor = {
                hash(processor_str): (
                    # The deserialized custom logit processor object
                    CustomLogitProcessor.from_str(processor_str),
                    # The mask tensor for the requests that use this custom logit processor
                    torch.zeros(len(reqs), dtype=torch.bool)
                    .scatter_(0, torch.tensor(true_indices), True)
                    .to(device, non_blocking=True),
                )
                for processor_str, true_indices in processor_dict.items()
            }
            custom_params = [r.sampling_params.custom_params for r in reqs]
        else:
            merged_custom_logit_processor = None
            custom_params = None

        # Each penalizer will do nothing if it evaluates itself as not required by looking at
        # the sampling_params of the requests (see {_is_required()} of each penalizer), so this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we could choose not to create the penalizer instances when they are not required,
        # that would add additional complexity to the {ScheduleBatch} class, especially since we
        # need to handle the {filter_batch()} and {merge_batch()} cases as well.
        penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
            },
        )

        ret = cls(
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
            min_ps=min_ps,
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            vocab_size=vocab_size,
            penalizer_orchestrator=penalizer_orchestrator,
            has_custom_logit_processor=has_custom_logit_processor,
            custom_params=custom_params,
            custom_logit_processor=merged_custom_logit_processor,
            device=device,
        )
        return ret

    def __len__(self):
        return len(self.temperatures)

    def update_regex_vocab_mask(self):
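        """Allocate and fill the grammar vocab mask for the current batch."""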
        if not self.grammars:
            self.vocab_mask = None
            self.apply_mask_func = None
            return

        # Find a grammar from the list
        first_grammar = next(grammar for grammar in self.grammars if grammar)

        # TODO(lianmin): Maybe we can reuse the existing mask?
        self.vocab_mask = first_grammar.allocate_vocab_mask(
            vocab_size=self.vocab_size,
            batch_size=len(self.temperatures),
            device=self.device,
        )
        self.apply_mask_func = (
            first_grammar.apply_vocab_mask
        )  # use the grammar class's static method

        # Fill the mask for each request's unfinished grammar
        for i, grammar in enumerate(self.grammars):
            if grammar and not grammar.finished:
                grammar.fill_vocab_mask(self.vocab_mask, i)

        # Move the mask to the device if needed
        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

    def update_penalties(self):
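        """Precompute the additive penalty tensor so it can be applied to logits later."""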
        if self.penalizer_orchestrator.is_required:
            self.linear_penalty = torch.zeros(
                (len(self.temperatures), self.vocab_size),
                dtype=torch.float32,
                device=self.temperatures.device,
            )
            self.penalizer_orchestrator.apply(self.linear_penalty)
        else:
            self.linear_penalty = None

    def apply_logits_bias(self, logits: torch.Tensor):
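        """Apply penalties and the grammar vocab mask to `logits` in place."""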
        if self.linear_penalty is not None:
            # Used in the overlap mode
            logits.add_(self.linear_penalty)

        if self.penalizer_orchestrator and self.penalizer_orchestrator.is_required:
            # Used in the non-overlap mode
            self.penalizer_orchestrator.apply(logits)

        if self.vocab_mask is not None:
            self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)

    def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
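        """Keep only the requests at `keep_indices`; `keep_indices_device` holds the
        same indices as a tensor for device-side indexing."""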
        self.penalizer_orchestrator.filter(keep_indices_device)

        if self.has_custom_logit_processor:
            self._filter_batch_custom_logit_processor(keep_indices, keep_indices_device)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            value = getattr(self, item, None)
            setattr(self, item, value[keep_indices_device])

    def _filter_batch_custom_logit_processor(
        self, keep_indices: List[int], keep_indices_device: torch.Tensor
    ):
        """Filter the custom logit processor and custom params"""
        self.custom_logit_processor = {
            k: (p, mask[keep_indices_device])
            for k, (p, mask) in self.custom_logit_processor.items()
            if torch.any(
                mask[keep_indices_device]
            )  # drop custom logit processors whose masks become all False after filtering
        }
        self.custom_params = [self.custom_params[i] for i in keep_indices]

        # If the custom logit processor is an empty dict, set the flag to False,
        # and set the custom logit processor and custom params to None.
        if len(self.custom_logit_processor) == 0:
            self.custom_logit_processor = None
            self.custom_params = None
            self.has_custom_logit_processor = False

    @staticmethod
    def merge_custom_logit_processor(
        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
        bs1: int,
        bs2: int,
        device: str,
    ):
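        """Merge two custom-logit-processor dicts, concatenating their request masks."""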
        if lhs is None and rhs is None:
            return None
        lhs, rhs = lhs or {}, rhs or {}

        keys = set(lhs.keys()).union(set(rhs.keys()))
        merged_dict = {}

        for k in keys:
            # Get the logit processor object
            processor = lhs[k][0] if k in lhs else rhs[k][0]
            # Get and merge the mask tensors from the two dicts
            left_mask = (
                lhs[k][1]
                if k in lhs
                else torch.zeros(bs1, dtype=torch.bool, device=device)
            )
            right_mask = (
                rhs[k][1]
                if k in rhs
                else torch.zeros(bs2, dtype=torch.bool, device=device)
            )
            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))

            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
                f"\n{lhs=}\n{rhs=}"
            )

        return merged_dict

    def merge_batch(self, other: "SamplingBatchInfo"):
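        """Concatenate `other` into this batch, appending its requests after `self`'s."""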
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        # Merge the custom logit processors and custom params lists
        if self.has_custom_logit_processor or other.has_custom_logit_processor:
            # Merge the custom logit processors
            self.custom_logit_processor = (
                SamplingBatchInfo.merge_custom_logit_processor(
                    self.custom_logit_processor,
                    other.custom_logit_processor,
                    len(self),
                    len(other),
                    self.device,
                )
            )
            # Merge the custom params lists
            self.custom_params = self.custom_params or [None] * len(self)
            other.custom_params = other.custom_params or [None] * len(other)
            self.custom_params.extend(other.custom_params)

            # Set the flag since at least one of the two batches has a custom logit processor
            self.has_custom_logit_processor = True

        # Note: because the __len__() operator is defined on the temperatures tensor,
        # make sure any operation that uses len(self) or len(other) is done before
        # the temperatures tensors are concatenated below.
        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.cat([self_val, other_val]))

        self.is_all_greedy |= other.is_all_greedy
        self.need_min_p_sampling |= other.need_min_p_sampling