decode.py 33.1 KB
Newer Older
Byron Hsu's avatar
Byron Hsu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""
Life cycle of a request in the decode server

1. PreallocQueue:
    a. Initialize a receiver for each request
    b. The request handshakes first, then pre-allocates KV cache once enough free KV space is available.
    c. Move the request to TransferQueue.

2. TransferQueue:
    a. Poll the receiver to check the transfer state
    b. If the transfer has finished, move the request to waiting queue

3. WaitingQueue:
    a. Use the requests in the queue to construct a PrebuiltExtendBatch
    b. Skip the prefill forward pass and only populate metadata

4. RunningBatch:
    a. Merge the resolved PrebuiltExtendBatch into running batch to run decoding
"""

from __future__ import annotations

import logging
24
from collections import deque
Byron Hsu's avatar
Byron Hsu committed
25
from dataclasses import dataclass
26
from http import HTTPStatus
Byron Hsu's avatar
Byron Hsu committed
27
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
Byron Hsu's avatar
Byron Hsu committed
28
29
30
31

import torch
from torch.distributed import ProcessGroup

32
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
33
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
Byron Hsu's avatar
Byron Hsu committed
34
from sglang.srt.disaggregation.utils import (
Byron Hsu's avatar
Byron Hsu committed
35
    FAKE_BOOTSTRAP_HOST,
36
    DisaggregationMode,
37
    KVClassType,
38
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
39
    ReqToMetadataIdxAllocator,
40
41
    TransferBackend,
    get_kv_class,
42
    is_mla_backend,
Byron Hsu's avatar
Byron Hsu committed
43
    kv_to_page_indices,
Byron Hsu's avatar
Byron Hsu committed
44
    poll_and_all_reduce,
45
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
46
)
Byron Hsu's avatar
Byron Hsu committed
47
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
48
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
Byron Hsu's avatar
Byron Hsu committed
49
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
50
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
51
from sglang.srt.model_executor.forward_batch_info import ForwardMode
Byron Hsu's avatar
Byron Hsu committed
52
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
53
from sglang.srt.utils import require_mlp_sync
Byron Hsu's avatar
Byron Hsu committed
54
55
56
57

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
58
    from sglang.srt.managers.schedule_batch import Req
Byron Hsu's avatar
Byron Hsu committed
59
60
61
    from sglang.srt.managers.scheduler import Scheduler


Byron Hsu's avatar
Byron Hsu committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class DecodeReqToTokenPool:
    """
    The difference between DecodeReqToTokenPool and ReqToTokenPool is that
    DecodeReqToTokenPool reserves extra memory for pre-allocated requests.

    In ReqToTokenPool, if `--max-running-requests` is 8,
    #pre-allocated + #transfer + #running <= 8, even though there is often
    enough free memory to hold more pre-allocated requests.

    In DecodeReqToTokenPool, if `--max-running-requests` is 8,
    #running <= 8 and #pre-allocated + #transfer <= pre_alloc_size, so the free
    memory can be used to pre-allocate requests and unblock prefill.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        pre_alloc_size: int,
    ):
        """
        Args:
            size: Number of slots for running requests.
            max_context_len: Width of each row of `req_to_token`.
            device: Device on which `req_to_token` is allocated.
            enable_memory_saver: Forwarded to `TorchMemorySaverAdapter.create`.
            pre_alloc_size: Number of extra slots reserved on top of `size`
                for pre-allocated / transferring requests.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        self.pre_alloc_size = pre_alloc_size
        # Allocate the table inside the memory-saver region tagged as KV cache.
        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
            # One row per slot (running + pre-allocated); callers write the
            # KV-pool locations of a request's tokens into its row.
            self.req_to_token = torch.zeros(
                (size + pre_alloc_size, max_context_len),
                dtype=torch.int32,
                device=device,
            )

        self.clear()

    def write(self, indices, values):
        """Assign `values` into `req_to_token` at `indices`."""
        self.req_to_token[indices] = values

    def available_size(self):
        """Return the number of free slots."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Take `need_size` free slots from the front of the free list.

        Returns:
            The taken slot indices, or None if fewer than `need_size` slots
            are free (in which case nothing is taken).
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot (int) or several slots (iterable) to the free list."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every slot, including the pre-allocation slots, as free."""
        self.free_slots = list(range(self.size + self.pre_alloc_size))


Byron Hsu's avatar
Byron Hsu committed
123
124
125
@dataclass
class DecodeRequest:
    """A request tracked by the decode-side pre-allocation / transfer queues.

    Attributes:
        req: The scheduler request.
        kv_receiver: Receiver that is polled for handshake / transfer state
            and initialised with the destination page indices.
        waiting_for_input: Set to True once a handshake poll reports
            `KVPoll.WaitingForInput`; only requests with this flag set are
            considered for pre-allocation.
        metadata_buffer_index: Metadata buffer slot allocated for this
            request, or -1 while none has been allocated.
    """

    req: Req
    kv_receiver: BaseKVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1


class DecodePreallocQueue:
    """
    Store the requests that are preallocating.

    New requests are added with a KV receiver and wait here until their
    handshake reports `KVPoll.WaitingForInput`; `pop_preallocated` then
    reserves request slots, KV cache and a metadata buffer slot for them.
    Retracted requests are kept separately in `retracted_queue` and are
    re-admitted by `resume_retracted_reqs`.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        dp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        max_total_num_tokens: int,
        prefill_pp_size: int,
        num_reserved_decode_tokens: int,
        transfer_backend: TransferBackend,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.dp_size = dp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        self.max_total_num_tokens = max_total_num_tokens
        self.prefill_pp_size = prefill_pp_size
        self.num_reserved_decode_tokens = num_reserved_decode_tokens
        self.transfer_backend = transfer_backend

        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        # Requests retracted from the running batch, waiting to be resumed
        self.retracted_queue: List[Req] = []

        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> BaseKVManager:
        """Build the decode-side KV manager for the configured transfer backend.

        The manager is given the buffer layout of the target KV pool (plus the
        draft KV pool, if any) and of the metadata buffers.
        """
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()

        attn_tp_size = self.tp_size // self.dp_size
        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
        kv_args.decode_tp_size = attn_tp_size
        kv_args.prefill_pp_size = self.prefill_pp_size
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )
        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )

        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id
        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
        kv_manager = kv_manager_class(
            kv_args,
            DisaggregationMode.DECODE,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager

    def add(self, req: Req, is_retracted: bool = False) -> None:
        """Add a request to the pending queue.

        A request whose prompt alone exceeds the KV pool capacity is aborted
        and streamed back immediately. Retracted requests go to
        `retracted_queue`; new requests get a KV receiver (the FAKE backend's
        receiver when the bootstrap host is `FAKE_BOOTSTRAP_HOST`) and join
        `queue`.
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return

        if is_retracted:
            self.retracted_queue.append(req)
        else:
            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
                kv_receiver_class = get_kv_class(
                    TransferBackend.FAKE, KVClassType.RECEIVER
                )
            else:
                kv_receiver_class = get_kv_class(
                    self.transfer_backend, KVClassType.RECEIVER
                )

            kv_receiver = kv_receiver_class(
                mgr=self.kv_manager,
                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
                bootstrap_room=req.bootstrap_room,
                data_parallel_rank=req.data_parallel_rank,
            )

            self.queue.append(
                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
            )

    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
        """Abort `req` and return True if its prompt cannot fit in the KV pool."""
        if len(req.origin_input_ids) > self.max_total_num_tokens:
            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
            logger.error(message)
            prepare_abort(req, message)
            self.scheduler.stream_output([req], req.return_logprob)
            return True
        return False

    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
        """Add requests to the pending queue (see `add`)."""
        for req in reqs:
            self.add(req, is_retracted=is_retracted)

    def resume_retracted_reqs(self) -> List[Req]:
        """Re-admit retracted requests in FIFO order while resources allow.

        Stops at the first request that does not fit in the free request
        slots or the allocatable KV budget. Each resumed request gets its KV
        pre-allocated and its KV cache loaded back via `req.load_kv_cache`.

        Returns:
            The resumed requests; they are removed from `retracted_queue`.
        """
        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible

        # allocate memory
        resumed_reqs = []
        indices_to_remove = set()
        # Retracted requests are being resumed here, so they must not also be
        # subtracted from the budget.
        allocatable_tokens = self._allocatable_tokens(count_retracted=False)

        for i, req in enumerate(self.retracted_queue):
            if self.req_to_token_pool.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(req.origin_input_ids)
                + len(req.output_ids)
                + self.num_reserved_decode_tokens
            )
            if required_tokens_for_request > allocatable_tokens:
                break

            resumed_reqs.append(req)
            indices_to_remove.add(i)
            req.is_retracted = False
            self._pre_alloc(req)
            allocatable_tokens -= required_tokens_for_request

            # load from cpu, release the cpu copy
            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)

        self.retracted_queue = [
            entry
            for i, entry in enumerate(self.retracted_queue)
            if i not in indices_to_remove
        ]

        return resumed_reqs

    def _update_handshake_waiters(self) -> None:
        """Poll receivers still handshaking and update their state.

        `WaitingForInput` marks the request ready for pre-allocation; `Failed`
        aborts the request with an INTERNAL_SERVER_ERROR status (it is removed
        from the queue later by `pop_preallocated`).
        """
        if not self.queue:
            return

        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        for decode_req, poll in zip(self.queue, polls):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                try:
                    decode_req.kv_receiver.failure_exception()
                except Exception as e:
                    error_message += f" with exception {e}"
                logger.error(error_message)
                prepare_abort(
                    decode_req.req,
                    error_message,
                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                )
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO).

        Aborted requests are streamed back and dropped. Requests whose
        handshake is complete are pre-allocated (request slot, KV cache and
        metadata buffer slot) and their receivers initialised, in order,
        until the first one that does not fit.
        """
        self._update_handshake_waiters()

        preallocated_reqs = []
        indices_to_remove = set()

        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
        retractable_tokens = sum(
            len(r.origin_input_ids) + len(r.output_ids)
            for r in self.scheduler.running_batch.reqs
        )
        allocatable_tokens = self._allocatable_tokens(
            retractable_tokens=retractable_tokens, count_retracted=True
        )
        # First, remove all failed requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)

        # Then, preallocate the remaining requests if possible
        for i, decode_req in enumerate(self.queue):
            if i in indices_to_remove:
                continue

            if not decode_req.waiting_for_input:
                continue

            if self.req_to_token_pool.available_size() <= 0:
                break

            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            # Memory estimation: don't add if the projected memory cannot be met
            # TODO: add new_token ratio
            origin_input_len = len(decode_req.req.origin_input_ids)
            required_tokens_for_request = (
                origin_input_len + self.num_reserved_decode_tokens
            )

            # The request must fit now AND be able to reach max_new_tokens if
            # everything in the running batch were retracted. Because
            # max(a, b) >= a, this single check also covers
            # `required_tokens_for_request > allocatable_tokens`.
            if (
                max(
                    required_tokens_for_request,
                    origin_input_len
                    + decode_req.req.sampling_params.max_new_tokens
                    - retractable_tokens,
                )
                > allocatable_tokens
            ):
                break

            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)

            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    :origin_input_len
                ]
                .cpu()
                .numpy()
            )

            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            page_indices = kv_to_page_indices(
                kv_indices, self.token_to_kv_pool_allocator.page_size
            )
            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return preallocated_reqs

    @property
    def num_tokens_pre_allocated(self):
        """Total `fill_ids` length of the requests in the transfer queue."""
        return sum(
            len(decode_req.req.fill_ids) for decode_req in self.transfer_queue.queue
        )

    def _allocatable_tokens(
        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
    ) -> int:
        """Estimate the KV tokens still available for new pre-allocations.

        Starting from the allocator's free size, subtract the larger of:
          * `num_reserved_decode_tokens` for every running, transferring and
            waiting request, and
          * (only when `retractable_tokens` is given and the running batch is
            non-empty) the largest `max_new_tokens + prompt_len -
            retractable_tokens` over running requests.
        Also subtract the decode reservation of the last batch if it was an
        extend batch, and, when `count_retracted` is True, the tokens needed
        to resume every retracted request.
        """
        running_reqs = self.scheduler.running_batch.reqs
        need_space_for_single_req = (
            max(
                x.sampling_params.max_new_tokens
                + len(x.origin_input_ids)
                - retractable_tokens
                for x in running_reqs
            )
            if retractable_tokens is not None and len(running_reqs) > 0
            else 0
        )

        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
            # preserve some space for future decode
            self.num_reserved_decode_tokens
            * (
                len(running_reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            ),
            # make sure each request can finish if reach max_tokens with all other requests retracted
            need_space_for_single_req,
        )

        # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
        #       the extend batch is not in any queue, so we need to explicitly add the tokens slots here
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )

        if count_retracted:
            allocatable_tokens -= sum(
                len(req.origin_input_ids)
                + len(req.output_ids)
                + self.num_reserved_decode_tokens
                for req in self.retracted_queue
            )
        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool.

        Assigns `req.req_pool_idx`, allocates KV slots for the prompt plus all
        but the last output token, writes them into the request's row of
        `req_to_token`, and sets `req.fill_ids` / `req.extend_input_len`.

        Returns:
            The allocated KV locations.
        """
        req_pool_indices = self.req_to_token_pool.alloc(1)

        assert (
            req_pool_indices is not None
        ), "req_pool_indices is full! There is a bug in memory estimation."

        req.req_pool_idx = req_pool_indices[0]

        # NOTE(review): the last output token is excluded, presumably because
        # its KV has not been computed yet (it is the next decode input) — confirm.
        num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        if self.token_to_kv_pool_allocator.page_size == 1:
            kv_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
        else:
            device = self.token_to_kv_pool_allocator.device
            kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
                prefix_lens=torch.tensor([0], dtype=torch.int64, device=device),
                seq_lens=torch.tensor([num_tokens], dtype=torch.int64, device=device),
                last_loc=torch.tensor([-1], dtype=torch.int64, device=device),
                extend_num_tokens=num_tokens,
            )

        assert (
            kv_loc is not None
        ), "KV cache is full! There is a bug in memory estimation."

        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)

        return kv_loc


class DecodeTransferQueue:
    """
    Store the requests that are polling kv.

    Requests arrive here after pre-allocation; `pop_transferred` polls their
    receivers, aborts failed transfers and returns the requests whose KV
    transfer succeeded.
    """

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        tp_rank: int,
        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache
        self.spec_algorithm = scheduler.spec_algorithm

    def add(self, decode_req: DecodeRequest) -> None:
        """Append one request to the transfer queue."""
        self.queue.append(decode_req)

    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
        """Append several requests to the transfer queue."""
        self.queue.extend(decode_reqs)

    def pop_transferred(self) -> List[Req]:
        """Poll all receivers and remove finished (failed or successful) transfers.

        Failed transfers are aborted and streamed back. Successful transfers
        get their first output token copied in from the metadata buffer;
        those that still need decoding are returned. The metadata buffer
        slots of every removed request are freed.

        Returns:
            Requests whose transfer succeeded and that should continue decoding.

        Raises:
            ValueError: if a receiver reports an unexpected poll state.
        """
        if not self.queue:
            return []
        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                self._abort_failed_transfer(decode_req)
                indices_to_remove.add(i)
            elif poll == KVPoll.Success:
                if self._commit_transferred_req(decode_req):
                    transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return transferred_reqs

    def _abort_failed_transfer(self, decode_req: DecodeRequest) -> None:
        """Abort a request whose KV transfer failed and release its KV cache."""
        error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
        try:
            decode_req.kv_receiver.failure_exception()
        except Exception as e:
            error_message += f" with exception {e}"
        logger.error(error_message)
        prepare_abort(
            decode_req.req,
            error_message,
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
        )
        self.scheduler.stream_output([decode_req.req], decode_req.req.return_logprob)
        # unlock the kv cache or it will have memory leak
        self.tree_cache.cache_finished_req(decode_req.req)

    def _commit_transferred_req(self, decode_req: DecodeRequest) -> bool:
        """Apply the transferred metadata (first token, logprobs, hidden states).

        Returns:
            True if the request still needs decoding; False if it finished
            right away because `max_new_tokens == 1`.
        """
        req = decode_req.req
        (
            output_id,
            output_token_logprobs_val,
            output_token_logprobs_idx,
            output_top_logprobs_val,
            output_top_logprobs_idx,
            output_hidden_states,
        ) = self.metadata_buffers.get_buf(decode_req.metadata_buffer_index)

        req.output_ids.append(output_id[0].item())
        if not self.spec_algorithm.is_none():
            req.hidden_states_tensor = output_hidden_states
        if req.return_logprob:
            req.output_token_logprobs_val.append(output_token_logprobs_val[0].item())
            req.output_token_logprobs_idx.append(output_token_logprobs_idx[0].item())
            req.output_top_logprobs_val.append(
                output_top_logprobs_val[: req.top_logprobs_num].tolist()
            )
            req.output_top_logprobs_idx.append(
                output_top_logprobs_idx[: req.top_logprobs_num].tolist()
            )

        if hasattr(decode_req.kv_receiver, "clear"):
            decode_req.kv_receiver.clear()

        # special handling for sampling_params.max_new_tokens == 1
        if req.sampling_params.max_new_tokens == 1:
            # finish immediately
            req.check_finished()
            self.scheduler.stream_output([req], req.return_logprob)
            self.tree_cache.cache_finished_req(req)
            return False
        return True


class SchedulerDisaggregationDecodeMixin:
    """Decode-side scheduler methods for disaggregation mode.

    Every method is annotated ``self: Scheduler`` and relies on attributes and
    methods provided by the scheduler it is mixed into (e.g. ``waiting_queue``,
    ``running_batch``, ``disagg_decode_prealloc_queue``,
    ``disagg_decode_transfer_queue``, ``run_batch``, ``process_batch_result``).
    """

    @torch.no_grad()
    def event_loop_normal_disagg_decode(self: Scheduler):
        """A normal scheduler loop for decode worker in disaggregation mode."""

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        # The prebuilt extend batch is never forwarded here, so
                        # run an idle batch instead (presumably so this rank still
                        # takes part in the MLP sync — TODO confirm).
                        self._prepare_idle_batch_and_run(None)
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)
            elif prepare_mlp_sync_flag:
                batch, _ = self._prepare_idle_batch_and_run(None)

            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

            self.last_batch = batch

    @torch.no_grad()
    def event_loop_overlap_disagg_decode(self: Scheduler):
        """An overlap scheduler loop for decode worker in disaggregation mode.

        The result of a batch launched in one iteration is pushed onto
        ``result_queue`` and processed in the following iteration, after the
        next batch has been launched.
        """
        result_queue = deque()
        self.last_batch: Optional[ScheduleBatch] = None
        self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch
            # Set to True whenever this iteration appends a result to `result_queue`.
            last_batch_in_queue = False

            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(
                        batch.reqs, any(req.return_logprob for req in batch.reqs)
                    )
                    if prepare_mlp_sync_flag:
                        batch_, result = self._prepare_idle_batch_and_run(
                            None, delay_process=True
                        )
                        if batch_:
                            result_queue.append((batch_.copy(), result))
                            last_batch_in_queue = True
                else:
                    if prepare_mlp_sync_flag:
                        self.prepare_mlp_sync_batch(batch)
                    result = self.run_batch(batch)
                    result_queue.append((batch.copy(), result))

                    if (self.last_batch is None) or (not self.last_batch_in_queue):
                        # Create a dummy first batch to start the pipeline for overlap schedule.
                        # It is now used for triggering the sampling_info_done event.
                        tmp_batch = ScheduleBatch(
                            reqs=None,
                            forward_mode=ForwardMode.DUMMY_FIRST,
                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                        )
                        self.set_next_batch_sampling_info_done(tmp_batch)
                    last_batch_in_queue = True

            elif prepare_mlp_sync_flag:
                batch, result = self._prepare_idle_batch_and_run(
                    None, delay_process=True
                )
                if batch:
                    result_queue.append((batch.copy(), result))
                    last_batch_in_queue = True

            # Process the result queued in the previous iteration, if any. A
            # prebuilt extend batch queues nothing unless an idle MLP-sync
            # batch was run alongside it.
            if self.last_batch and self.last_batch_in_queue:
                tmp_batch, tmp_result = result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(tmp_batch, tmp_result)

            if batch is None and (
                len(self.waiting_queue)
                + len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

            self.last_batch = batch
            self.last_batch_in_queue = last_batch_in_queue

    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
        """Run whatever ``prepare_mlp_sync_batch(batch)`` returns, if anything.

        Args:
            batch: Forwarded to ``prepare_mlp_sync_batch``; all callers in this
                class pass ``None``.
            delay_process: When True, do not call ``process_batch_result``;
                the caller queues the result and processes it later.

        Returns:
            ``(prepared_batch, result)``. ``result`` is ``None`` when the
            prepared batch is falsy, i.e. there was nothing to run.
        """
        batch = self.prepare_mlp_sync_batch(batch)
        result = None
        if batch:
            result = self.run_batch(batch)
            if not delay_process:
                self.process_batch_result(batch, result)
        return batch, result

    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[ScheduleBatch]:
        """Create fake completed prefill if possible and merge with running batch.

        First folds the previous iteration's extend batch (if any) into
        ``running_batch``. A newly prebuilt batch takes priority; otherwise
        the updated running batch is returned.

        Returns:
            The prebuilt batch, the non-empty running batch, or ``None`` when
            there is nothing to run.
        """
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen in decode instance.
            assert self.chunked_req is None
            # Filter finished batches.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()
        if new_prebuilt_batch:
            # The event loop stores this as `last_batch`; the merge block above
            # folds it into `running_batch` on a later call if it is an extend batch.
            return new_prebuilt_batch

        if self.running_batch.is_empty():
            return None

        self.running_batch = self.update_running_batch(self.running_batch)
        return None if self.running_batch.is_empty() else self.running_batch

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a ScheduleBatch of fake completed prefills from ``waiting_queue``.

        The first ``min(req_to_token_pool.size, max_running_requests) -
        running_batch.batch_size()`` waiting requests (FIFO order) are admitted;
        the rest stay in ``waiting_queue``.

        Returns:
            The prebuilt batch, or ``None`` if no request could be admitted.
        """
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()
        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
        # Clamp at 0: a negative slice bound would count from the end of the list.
        num_not_used_batch = max(0, batch_size - curr_batch_size)

        # Pop requests from the waiting queue: at most `num_not_used_batch` of
        # them can join; the rest keep waiting.
        can_run_list: List[Req] = self.waiting_queue[:num_not_used_batch]
        self.waiting_queue = self.waiting_queue[num_not_used_batch:]
        if len(can_run_list) == 0:
            return None

        for req in can_run_list:
            req.init_next_round_input(self.tree_cache)

        # construct a schedule batch with those requests
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)

        return new_batch

    def process_decode_queue(self: Scheduler):
        """Advance the decode-side queues and feed ready requests to ``waiting_queue``.

        Retracted requests are resumed first; while any remain retracted, no
        new requests are admitted. Otherwise preallocated requests move to the
        transfer queue, and requests whose KV has arrived join ``waiting_queue``.
        """
        # try to resume retracted requests if there is enough space for another `num_reserved_decode_tokens` decode steps
        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
        self.waiting_queue.extend(resumed_reqs)
        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
            # if there are still retracted requests, we do not allocate new requests
            return

        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        # the requests whose kv has arrived
        alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred()
        self.waiting_queue.extend(alloc_reqs)