from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Callable, List, Optional

import torch

import sglang.srt.sampling.penaltylib as penaltylib

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


@dataclasses.dataclass
class SamplingBatchInfo:
    # Batched sampling params
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor

    # All requests use greedy sampling
    is_all_greedy: bool

    # Dispatch in CUDA graph
    need_min_p_sampling: bool

    # Bias Tensors
    vocab_size: int
    grammars: Optional[List] = None
    logit_bias: Optional[torch.Tensor] = None
    vocab_mask: Optional[torch.Tensor] = None
    apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None

    # Penalizer
    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
    linear_penalties: Optional[torch.Tensor] = None
    scaling_penalties: Optional[torch.Tensor] = None

    # Device
    device: str = "cuda"

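    # Illustrative sketch: constructing a minimal two-request batch by hand
    # (e.g. in a unit test). The concrete values below are assumptions for
    # illustration, not defaults.
    #
    #   info = SamplingBatchInfo(
    #       temperatures=torch.tensor([[0.7], [1.0]]),
    #       top_ps=torch.tensor([0.9, 1.0]),
    #       top_ks=torch.tensor([40, 1], dtype=torch.int32),
    #       min_ps=torch.tensor([0.0, 0.0]),
    #       is_all_greedy=False,
    #       need_min_p_sampling=False,
    #       vocab_size=32000,
    #       device="cpu",
    #   )
    #   assert len(info) == 2
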
    @classmethod
    def from_schedule_batch(
        cls,
        batch: ScheduleBatch,
        vocab_size: int,
        disable_penalizer: bool,
    ):
        reqs = batch.reqs
        device = batch.device
        temperatures = (
            torch.tensor(
                [r.sampling_params.temperature for r in reqs],
                dtype=torch.float,
            )
            .view(-1, 1)
            .to(device, non_blocking=True)
        )
        top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)
        top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
        ).to(device, non_blocking=True)
        min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

        ret = cls(
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
            min_ps=min_ps,
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
            vocab_size=vocab_size,
            device=device,
        )
        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

        # Each penalizer will do nothing if it evaluates itself as not required by looking at
        # the sampling_params of the requests (see {_is_required()} of each penalizer), so this
        # should not add heavy computation overhead other than simple checks.
        #
        # While we could avoid creating the class instances at all when they are not required,
        # that would add additional complexity to the {ScheduleBatch} class, especially since
        # we need to handle the {filter_batch()} and {merge_batch()} cases as well.
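        #
        # For reference, a hypothetical shape of such a check (not the actual
        # penaltylib source; the class and method names are assumptions):
        #
        #   class BatchedFrequencyPenalizer(_BatchedPenalizer):
        #       def _is_required(self) -> bool:
        #           return any(
        #               req.sampling_params.frequency_penalty != 0.0
        #               for req in self.orchestrator.reqs()
        #           )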
        if disable_penalizer:
            ret.penalizer_orchestrator = None
        else:
            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
                vocab_size=vocab_size,
                batch=batch,
                device=batch.device,
                Penalizers={
                    penaltylib.BatchedFrequencyPenalizer,
                    penaltylib.BatchedMinNewTokensPenalizer,
                    penaltylib.BatchedPresencePenalizer,
                    penaltylib.BatchedRepetitionPenalizer,
                },
            )

        # Handle logit bias but only allocate when needed
        ret.logit_bias = None

        return ret
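
    # Usage sketch (assumes a prepared ScheduleBatch; 32000 is an arbitrary
    # illustrative vocab size):
    #
    #   sampling_info = SamplingBatchInfo.from_schedule_batch(
    #       batch, vocab_size=32000, disable_penalizer=False
    #   )
    #   assert len(sampling_info) == len(batch.reqs)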

    def __len__(self):
        return len(self.temperatures)

    def update_penalties(self):
        if not self.penalizer_orchestrator:
            return

        self.scaling_penalties = None
        self.linear_penalties = None

        for penalizer in self.penalizer_orchestrator.penalizers.values():
            if not penalizer.is_prepared():
                continue

            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
                self.scaling_penalties = penalizer.cumulated_repetition_penalties
            else:
                if self.linear_penalties is None:
                    bs = self.penalizer_orchestrator.batch.batch_size()
                    self.linear_penalties = torch.zeros(
                        (bs, self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                self.linear_penalties = penalizer.apply(self.linear_penalties)

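    # Downstream, a sampler typically consumes these tensors roughly as follows
    # (an illustrative sketch, not the actual sampler code; `logits` is assumed
    # to be a [batch_size, vocab_size] float tensor):
    #
    #   if sampling_info.linear_penalties is not None:
    #       logits = logits + sampling_info.linear_penalties
    #   if sampling_info.scaling_penalties is not None:
    #       logits = torch.where(
    #           logits > 0,
    #           logits / sampling_info.scaling_penalties,
    #           logits * sampling_info.scaling_penalties,
    #       )
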
    def update_regex_vocab_mask(self):
        if not self.grammars or not any(grammar for grammar in self.grammars):
            self.vocab_mask = None
            self.apply_mask = None
            return

        # Find the first non-None grammar in the list
        grammar = next(grammar for grammar in self.grammars if grammar is not None)

        # maybe we can reuse the existing mask?
        self.vocab_mask = grammar.allocate_vocab_mask(
            vocab_size=self.vocab_size,
            batch_size=len(self.temperatures),
            device=self.device,
        )
        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method

        for i, grammar in enumerate(self.grammars):
            if grammar is not None:
                grammar.fill_vocab_mask(self.vocab_mask, i)
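
    # The sampler can then mask logits in a single call (illustrative sketch,
    # matching the `apply_mask` signature declared above):
    #
    #   if sampling_info.apply_mask is not None:
    #       sampling_info.apply_mask(logits, sampling_info.vocab_mask)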

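    # Keep only the rows that belong to unfinished requests; `new_indices`
    # selects the surviving batch rows in their new order.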
    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        if self.penalizer_orchestrator:
            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            value = getattr(self, item, None)
            if value is not None:  # logit_bias can be None
                setattr(self, item, value[new_indices])

    @staticmethod
    def merge_bias_tensor(
        lhs: torch.Tensor,
        rhs: torch.Tensor,
        bs1: int,
        bs2: int,
        device: str,
        default: int = 0,
    ):
        # Either side's bias tensor can be None
        if lhs is not None or rhs is not None:
            shape, dtype = None, None
            if lhs is not None:
                shape, dtype = lhs.shape[1:], lhs.dtype
            else:
                shape, dtype = rhs.shape[1:], rhs.dtype
            # Pad the missing side with `default`, matching the present side's
            # trailing shape and dtype.
            if lhs is None:
                lhs = torch.empty((bs1, *shape), dtype=dtype, device=device).fill_(default)
            if rhs is None:
                rhs = torch.empty((bs2, *shape), dtype=dtype, device=device).fill_(default)
            return torch.cat([lhs, rhs])

        return None
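
    # Behavior sketch (illustrative): merging a present bias with a missing one
    # pads the missing side with `default`:
    #
    #   lhs = torch.zeros((2, 4))
    #   merged = SamplingBatchInfo.merge_bias_tensor(lhs, None, 2, 3, "cpu")
    #   # merged.shape == (5, 4); the last 3 rows are filled with 0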

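    # Concatenate `other`'s per-request tensors after this batch's own; called
    # when two scheduled batches are merged.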
    def merge_batch(self, other: "SamplingBatchInfo"):
        if self.penalizer_orchestrator:
            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.cat([self_val, other_val]))

        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
            self.logit_bias, other.logit_bias, len(self), len(other), self.device
        )