"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Request scheduler policy"""

import os
import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode

# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


class SchedulePolicy:
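    """Decide the order in which waiting requests are scheduled for prefill."""
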
    def __init__(self, policy: str, tree_cache: BasePrefixCache):
        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
            # LPM and DFS-weight are meaningless when the tree cache is disabled.
            policy = "fcfs"

        self.policy = policy
        self.tree_cache = tree_cache

    def calc_priority(self, waiting_queue: List[Req]):
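        """Reorder `waiting_queue` in place according to the configured policy.

        Returns True if the matched prefixes were computed for the queued requests,
        False otherwise.
        """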
        if len(waiting_queue) > 128 and self.policy == "lpm":
            # Turn off the expensive prefix matching and sorting when the queue is large.
            policy = "fcfs"
        else:
            policy = self.policy

        # Compute matched prefix length
        prefix_computed = False
        if policy == "lpm" or policy == "dfs-weight":
            for r in waiting_queue:
                # NOTE: the prefix_indices must always be aligned with last_node
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                )

            prefix_computed = True

        if policy == "lpm":
66
            # Longest Prefix Match
Liangsheng Yin's avatar
Liangsheng Yin committed
67
            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
68
        elif policy == "fcfs":
69
            # first come first serve
70
            pass
71
        elif policy == "lof":
72
            # longest output first
Liangsheng Yin's avatar
Liangsheng Yin committed
73
            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
74
        elif policy == "random":
Liangsheng Yin's avatar
Liangsheng Yin committed
75
            random.shuffle(waiting_queue)
76
        elif policy == "dfs-weight":
Lianmin Zheng's avatar
Lianmin Zheng committed
77
            last_node_to_reqs = defaultdict(list)
Liangsheng Yin's avatar
Liangsheng Yin committed
78
            for req in waiting_queue:
Lianmin Zheng's avatar
Lianmin Zheng committed
79
80
81
                last_node_to_reqs[req.last_node].append(req)

            node_to_weight = defaultdict(int)
Liangsheng Yin's avatar
Liangsheng Yin committed
82
83
84
            for node in last_node_to_reqs:
                node_to_weight[node] = len(last_node_to_reqs[node])
            self.calc_weight(self.tree_cache.root_node, node_to_weight)
Lianmin Zheng's avatar
Lianmin Zheng committed
85

86
            waiting_queue.clear()
Liangsheng Yin's avatar
Liangsheng Yin committed
87
            self.get_dfs_priority(
88
89
90
91
                self.tree_cache.root_node,
                node_to_weight,
                last_node_to_reqs,
                waiting_queue,
Lianmin Zheng's avatar
Lianmin Zheng committed
92
93
            )
        else:
94
            raise ValueError(f"Unknown schedule_policy: {policy=}")
Lianmin Zheng's avatar
Lianmin Zheng committed
95

96
97
        return prefix_computed

98
    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
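        """Accumulate subtree weights bottom-up: each node's weight gains the
        weight of every child subtree."""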
        for child in cur_node.children.values():
            self.calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    def get_dfs_priority(
        self,
        cur_node: TreeNode,
        node_to_priority: Dict,
        last_node_to_reqs: Dict,
        q: List,
    ):
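        """Walk the tree depth-first, visiting heavier subtrees first, and append
        the requests attached to each visited node to `q`."""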
        childs = list(cur_node.children.values())
        childs.sort(key=lambda x: -node_to_priority[x])
        for child in childs:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])


class AddReqResult(Enum):
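    """Result of trying to add one more request to the prefill batch."""
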
    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests


class PrefillAdder:
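    """Greedily pack waiting requests into the next prefill batch under the
    remaining token budgets (total, input, and optional chunk size)."""
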
    def __init__(
        self,
        tree_cache: BasePrefixCache,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        self.tree_cache = tree_cache
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens

        self.req_states = None
        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

        if running_batch is not None:
            # Pre-remove the tokens which will be occupied by the running requests
            self.rem_total_tokens -= sum(
                [
                    min(
                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    )
                    * self.new_token_ratio
                    for r in running_batch.reqs
                ]
            )

    def budget_state(self):
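        """Report whether more requests can still be added under the current budgets."""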
        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
            return AddReqResult.NO_TOKEN

        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER

        return AddReqResult.CONTINUE

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
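        """Charge one request's input tokens and estimated new tokens against the budgets."""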
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.cur_rem_tokens -= extend_input_len
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: Req):
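        """Continue the prefill of a chunked (inflight) request.

        Returns the request itself if its prefill is still unfinished after this
        chunk, or None if this chunk completes the prefill.
        """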
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)

        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            (
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
        )

        # Return the request if the chunked prefill is not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node: TreeNode):
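        """Temporarily lock `last_node` in the prefix cache while a request is
        being considered, adjusting `rem_total_tokens` by the deltas reported by
        the cache lock/unlock calls."""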
        try:
            delta = self.tree_cache.inc_lock_ref(last_node)
            self.rem_total_tokens += delta
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def add_one_req_ignore_eos(self, req: Req):
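        """Add a request when the prefix cache is disabled and ignore_eos is set.

        Estimates future memory usage by simulating how the currently tracked
        requests finish and free their tokens, and returns NO_TOKEN if the new
        request would not fit.
        """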
        def add_req_state(r, insert_sort=False):
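            # Each req_states entry is (estimated tokens left to decode, tokens currently occupied).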
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
            )
            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
                r.output_ids
            )
            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

            if tokens_left > 0:
                if not insert_sort:
                    self.req_states.append((tokens_left, tokens_occupied))
                else:
                    i = 0
                    for i in range(len(self.req_states)):
                        if tokens_left <= self.req_states[i][0]:
                            break
                    self.req_states.insert(i, (tokens_left, tokens_occupied))

        if self.req_states is None:
            self.req_states = []
            add_req_state(req)
            if self.running_batch is not None:
                for r in self.running_batch.reqs:
                    add_req_state(r)
            for r in self.can_run_list:
                add_req_state(r)
            self.req_states.sort(key=lambda x: x[0])
        else:
            add_req_state(req, insert_sort=True)

        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
            decode_steps = (
                self.req_states[i + 1][0]
                if i + 1 < len(self.req_states)
                else tokens_left
            )
            bs = len(self.req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return AddReqResult.NO_TOKEN
            tokens_freed += tokens_occupied

        if (
            self.rem_chunk_tokens is None
            or req.extend_input_len <= self.rem_chunk_tokens
        ):
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
                req.extend_input_len,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_inflight_req = req
            self._prefill_one_req(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(self, req: Req):
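        """Try to add a new request to the prefill batch.

        Falls back to chunked prefill when the request does not fit into the
        remaining chunk budget. Returns an AddReqResult describing whether more
        requests can still be added.
        """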
        if req.sampling_params.ignore_eos and self.tree_cache.disable:
            return self.add_one_req_ignore_eos(req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER

        with self._lock_node(req.last_node):
            if total_tokens > self.rem_total_tokens:
                return AddReqResult.NO_TOKEN

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (req.return_logprob and req.normalized_prompt_logprob is None)
            ):
                # Non-chunked prefill. A request that still needs its normalized prompt
                # logprob is prefilled in one shot even if it exceeds the chunk budget.
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len,
                    input_tokens,
                    min(
                        req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    ),
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                if trunc_len == 0:
                    return AddReqResult.OTHER

                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return self.budget_state()