# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Request scheduler policy"""

import os
import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import TreeNode

# Clip the estimation of max_new_tokens for requests whose max_new_tokens is very large.
# This prevents the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
# condition. A request can still generate tokens until it hits its unclipped max_new_tokens.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)
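# For example, launching with SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION=8192 relaxes the clip;
# the value is read once at import time.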


class SchedulePolicy:
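    """Decide the order in which requests in the waiting queue are scheduled for prefill."""
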
    def __init__(self, policy: str, tree_cache: BasePrefixCache):
        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
            # LPM and DFS-weight are meaningless when the tree cache is disabled.
            policy = "fcfs"

        self.policy = policy
        self.tree_cache = tree_cache

    def calc_priority(self, waiting_queue: List[Req]):
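        """Reorder the waiting queue in place according to the scheduling policy.

        Returns whether prefix matching was computed for the queued requests.
        """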
        if len(waiting_queue) > 128 and self.policy == "lpm":
            # Turn off the expensive prefix matching and sorting when the waiting queue is large.
            policy = "fcfs"
        else:
            policy = self.policy

        # Compute matched prefix length
        prefix_computed = False
        if policy == "lpm" or policy == "dfs-weight":
            for r in waiting_queue:
                # NOTE: the prefix_indices must always be aligned with last_node
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=r.adjust_max_prefix_ids()
                )
            prefix_computed = True

        if policy == "lpm":
            # Longest Prefix Match
            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
        elif policy == "fcfs":
            # first come first serve
            pass
        elif policy == "lof":
            # longest output first
            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
        elif policy == "random":
            random.shuffle(waiting_queue)
        elif policy == "dfs-weight":
            last_node_to_reqs = defaultdict(list)
            for req in waiting_queue:
                last_node_to_reqs[req.last_node].append(req)

            node_to_weight = defaultdict(int)
            for node in last_node_to_reqs:
                node_to_weight[node] = len(last_node_to_reqs[node])
            self.calc_weight(self.tree_cache.root_node, node_to_weight)

            waiting_queue.clear()
            self.get_dfs_priority(
                self.tree_cache.root_node,
                node_to_weight,
                last_node_to_reqs,
                waiting_queue,
            )
        else:
            raise ValueError(f"Unknown schedule_policy: {policy=}")

        return prefix_computed

    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
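        """Propagate request counts up the tree so each node's weight covers its whole subtree."""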
        for child in cur_node.children.values():
            self.calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    def get_dfs_priority(
        self,
        cur_node: TreeNode,
        node_to_priority: Dict,
        last_node_to_reqs: Dict,
        q: List,
    ):
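        """Traverse the tree depth-first, visiting heavier subtrees first, and append each node's requests to q."""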
        children = list(cur_node.children.values())
        children.sort(key=lambda x: -node_to_priority[x])
        for child in children:
            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
        q.extend(last_node_to_reqs[cur_node])


class AddReqResult(Enum):
    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests


class PrefillAdder:
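    """Track the token budgets while greedily adding requests to the next prefill batch.

    A rough usage sketch (the budgets and waiting queue are assumed to be supplied
    by the caller; the real call site lives in the scheduler):

        adder = PrefillAdder(tree_cache, running_batch, new_token_ratio,
                             rem_total_tokens, rem_input_tokens, rem_chunk_tokens)
        for req in waiting_queue:
            if adder.add_one_req(req) != AddReqResult.CONTINUE:
                break
        batch_reqs = adder.can_run_list
    """
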
    def __init__(
        self,
        tree_cache: BasePrefixCache,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        self.tree_cache = tree_cache
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens

        self.req_states = None
        self.can_run_list = []
        self.new_inflight_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

        if running_batch is not None:
            # Pre-remove the tokens which will be occupied by the running requests
            self.rem_total_tokens -= sum(
                [
                    min(
                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    )
                    * self.new_token_ratio
                    for r in running_batch.reqs
                ]
            )

    def budget_state(self):
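        """Report whether more requests can still be added under the remaining token budgets."""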
        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
            return AddReqResult.NO_TOKEN

        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER

        return AddReqResult.CONTINUE

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
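        """Charge one request's prefill input and estimated new tokens against the budgets."""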
        self.rem_total_tokens -= extend_input_len + max_new_tokens
        self.cur_rem_tokens -= extend_input_len
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: Req):
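        """Resume a chunked (inflight) prefill request.

        Assumes chunked prefill is enabled, i.e. rem_chunk_tokens is not None.
        Returns the request again if its prefill is still not finished, else None.
        """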
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)

        self._prefill_one_req(
            len(req.prefix_indices),
            req.extend_input_len,
            (
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
        )

        # Return the request if its chunked prefill is not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node: TreeNode):
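        """Temporarily lock the prefix ending at last_node so it cannot be evicted.

        inc_lock_ref / dec_lock_ref return budget deltas that are applied to
        rem_total_tokens while the lock is held.
        """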
        try:
            delta = self.tree_cache.inc_lock_ref(last_node)
            self.rem_total_tokens += delta
            yield None
        finally:
            delta = self.tree_cache.dec_lock_ref(last_node)
            self.rem_total_tokens += delta

    def add_one_req_ignore_eos(self, req: Req):
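        """Handle the ignore_eos case when the tree cache is disabled.

        Simulates the future decode steps of all known requests to check that
        memory will not run out before they can finish, then adds the request
        (chunked if needed).
        """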
        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
            )
            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
                r.output_ids
            )
            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

            if tokens_left > 0:
                if not insert_sort:
                    self.req_states.append((tokens_left, tokens_occupied))
                else:
                    # Keep req_states sorted ascending by tokens_left; insert at
                    # the end if the new entry is the largest.
                    for i in range(len(self.req_states)):
                        if tokens_left <= self.req_states[i][0]:
                            break
                    else:
                        i = len(self.req_states)
                    self.req_states.insert(i, (tokens_left, tokens_occupied))

        if self.req_states is None:
            self.req_states = []
            add_req_state(req)
            if self.running_batch is not None:
                for r in self.running_batch.reqs:
                    add_req_state(r)
            for r in self.can_run_list:
                add_req_state(r)
            self.req_states.sort(key=lambda x: x[0])
        else:
            add_req_state(req, insert_sort=True)

        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
            decode_steps = (
                self.req_states[i + 1][0]
                if i + 1 < len(self.req_states)
                else tokens_left
            )
            bs = len(self.req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return AddReqResult.NO_TOKEN
            tokens_freed += tokens_occupied

        if (
            self.rem_chunk_tokens is None
            or req.extend_input_len <= self.rem_chunk_tokens
        ):
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
                req.extend_input_len,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_inflight_req = req
            self._prefill_one_req(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(self, req: Req):
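        """Try to add one request to the prefill batch, chunking it if necessary.

        Returns an AddReqResult telling the caller whether to keep adding requests.
        """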
        if req.sampling_params.ignore_eos and self.tree_cache.disable:
            return self.add_one_req_ignore_eos(req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER

        with self._lock_node(req.last_node):
            if total_tokens > self.rem_total_tokens:
                return AddReqResult.NO_TOKEN

            if (
                self.rem_chunk_tokens is None
                or input_tokens <= self.rem_chunk_tokens
                or (
                    req.return_logprob
                    and req.normalized_prompt_logprob is None
                    and req.logprob_start_len != len(req.origin_input_ids) - 1
                )
            ):
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len,
                    input_tokens,
                    min(
                        req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    ),
                )
            else:
                # Chunked prefill
                trunc_len = self.rem_chunk_tokens
                if trunc_len == 0:
                    return AddReqResult.OTHER

                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_inflight_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return self.budget_state()