import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import (
    GroupCoordinator,
    get_tp_group,
    patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
    EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    fast_topk,
    generate_token_bitmask,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    is_cuda,
    next_power_of_2,
)

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # The draft model doesn't use DP and has its own TP group.
    # We disable mscclpp for now because it doesn't support two comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # The draft and target workers each own their own KV cache pool.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # In most cases, EAGLE3 models don't share the lm_head with the target model,
            # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) do.
            if (
                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
                and self.draft_model_runner.model.load_lm_head_from_target
            ):
                self.draft_model_runner.model.set_embed_and_head(embed, head)
            else:
                self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            if self.draft_model_runner.model.hot_token_id is not None:
                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
                    embed.device
                )

        else:
            if self.hot_token_id is not None:
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
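        # Note: these scalar placeholders are only reassigned in _draft_preprocess_decode
        # when page_size > 1 and topk > 1; in the other cases they are still passed
        # to assign_draft_cache_locs, which presumably ignores them (assumption,
        # based only on how they are initialized here).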

    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        # Initialize decode attention backend
        self.draft_attn_backend = self._create_decode_backend()

        # Initialize prefill attention backend
        self.draft_extend_attn_backend = self._create_draft_extend_backend()

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def _create_backend(
        self, backend_name: str, backend_map: dict, error_template: str
    ):
        backend_type = getattr(self.server_args, backend_name)
        if backend_type is None:
            backend_type = self.server_args.attention_backend

        if backend_type not in backend_map:
            raise ValueError(error_template.format(backend_type=backend_type))

        return backend_map[backend_type]()

    def _create_decode_backend(self):
        backend_map = {
            "flashinfer": self._create_flashinfer_decode_backend,
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
        }

        return self._create_backend(
            "decode_attention_backend",
            backend_map,
            "EAGLE is not supported in decode attention backend {backend_type}",
        )

    def _create_draft_extend_backend(self):
        backend_map = {
            "flashinfer": self._create_flashinfer_prefill_backend,
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }

        return self._create_backend(
            "prefill_attention_backend",
            backend_map,
            "EAGLE is not supported in prefill attention backend {backend_type}",
        )

    def _create_flashinfer_decode_backend(self):
        if not global_server_args_dict["use_mla_backend"]:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAMultiStepDraftBackend,
            )

            self.has_prefill_wrapper_verify = True
            return FlashInferMLAMultiStepDraftBackend(
                self.draft_model_runner, self.topk, self.speculative_num_steps
            )

    def _create_triton_decode_backend(self):
        from sglang.srt.layers.attention.triton_backend import (
            TritonMultiStepDraftBackend,
        )

        return TritonMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_aiter_decode_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend

        return AiterMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_fa3_decode_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionMultiStepBackend,
        )

        return FlashAttentionMultiStepBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashmla_decode_backend(self):
        from sglang.srt.layers.attention.flashmla_backend import (
            FlashMLAMultiStepDraftBackend,
        )

        return FlashMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mha_decode_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import (
            TRTLLMHAAttnMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMHAAttnMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_trtllm_mla_decode_backend(self):
        if not global_server_args_dict["use_mla_backend"]:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

    def _create_flashinfer_prefill_backend(self):
        if not global_server_args_dict["use_mla_backend"]:
            from sglang.srt.layers.attention.flashinfer_backend import (
                FlashInferAttnBackend,
            )

            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
        else:
            from sglang.srt.layers.attention.flashinfer_mla_backend import (
                FlashInferMLAAttnBackend,
            )

            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_triton_prefill_backend(self):
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_aiter_prefill_backend(self):
        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_fa3_prefill_backend(self):
        from sglang.srt.layers.attention.flashattention_backend import (
            FlashAttentionBackend,
        )

        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mha_prefill_backend(self):
        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend

        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)

    def _create_trtllm_mla_prefill_backend(self):
        if not global_server_args_dict["use_mla_backend"]:
            raise ValueError(
                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
            )

        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            tic = time.perf_counter()
            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
            )
            self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
                self
            )
            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
            )

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place. It is not guaranteed
        that the final output batch has the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted next tokens,
            the batch id (used for overlap schedule), the number of accepted tokens, and
            a flag indicating whether the CUDA graph was used.
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            logits_output, next_token_ids, bid, seq_lens_cpu = (
                self.forward_target_extend(batch)
            )
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                )
            return logits_output, next_token_ids, bid, 0, False
        else:
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            with self.draft_tp_context(self.draft_model_runner.tp_group):
                # NOTE: We should use `check_forward_draft_extend_after_decode`
                # when DP attention is enabled, but it is slow. Skip it for now.
                if (
                    self.server_args.enable_dp_attention
                    or batch.spec_info.verified_id.shape[0] > 0
                ):
                    # decode is not finished
                    self.forward_draft_extend_after_decode(batch)

            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph,
            )

    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
        if not self.server_args.enable_dp_attention:
            return local_need_forward

        global_need_forward = torch.tensor(
            [
                (local_need_forward),
            ],
            dtype=torch.int64,
        )
        torch.distributed.all_reduce(
            global_need_forward, group=get_tp_group().cpu_group
        )
        global_need_forward_cnt = global_need_forward[0].item()
        need_forward = global_need_forward_cnt > 0
        return need_forward

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The logits output. It contains the full hidden states.
            next_token_ids: The next token ids generated.
            bid: The model batch ID, used for the overlap schedule.
            seq_lens_cpu: The sequence lengths on CPU.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return (
            logits_output,
            next_token_ids,
            model_worker_batch.bid,
            model_worker_batch.seq_lens_cpu,
        )

    def _draft_preprocess_decode(self, batch: ScheduleBatch):
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        # Layout of the out_cache_loc
        # [       topk 0         ] [       topk 1         ]
        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
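        # For example (illustrative values), with topk=2 and speculative_num_steps=3,
        # each sequence gets 2 * 3 = 6 consecutive slots: three slots (one per draft
        # step) for top-k path 0, followed by three slots for top-k path 1.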
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                )
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    tok-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO(lmzheng): The current implementation is still a fake support
                # for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real kv cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating KV
                # cache seems okay because the draft KV cache only has one layer.
                # see a related copy operation in MHATokenToKVPool::move_kv_cache.

                (
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    self.num_new_pages_per_topk,
                    self.extend_lens,
                ) = get_last_loc_large_page_size_large_top_k(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                    self.topk,
                    self.page_size,
                )

                # TODO(lmzheng): remove this device sync
                extend_num_tokens = torch.sum(self.extend_lens).item()

            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            self.extend_lens,
            self.num_new_pages_per_topk,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
        )

        if self.page_size > 1 and self.topk > 1:
            # Remove padded slots
            out_cache_loc = out_cache_loc[
                : num_seqs * self.topk * self.speculative_num_steps
            ]

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        batch.return_hidden_states = False
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

    def _draft_preprocess_idle(self, batch: ScheduleBatch):
        batch.spec_info = EagleDraftInput.create_idle_input(
            device=self.device,
            hidden_size=self.model_config.hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

    def draft(self, batch: ScheduleBatch):
        # Parse args
        if batch.forward_mode.is_idle():
            self._draft_preprocess_idle(batch)
        else:
            self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)

        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        spec_info.num_tokens_per_batch = self.topk
        spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        if batch.forward_mode.is_idle():
            return EagleVerifyInput.create_idle_input(
                self.topk,
                self.speculative_num_steps,
                self.speculative_num_draft_tokens,
            )

        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.server_args.speculative_num_draft_tokens,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=forward_batch.seq_lens_sum,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
        )

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info = forward_batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )
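        # After the reshape/permute above, out_cache_loc has shape
        # (speculative_num_steps, batch_size * topk), so out_cache_loc[i] below
        # holds the cache slots for draft step i across all (batch, topk) paths.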

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward; we get 1 token from the draft
            # prefill and (#spec steps - 1) tokens here.
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            # This is a temporary fix for the case that the user is using standalone
            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
            # rope kernel needs cache_loc to be contiguous.
            if (
                self.server_args.speculative_algorithm == "STANDALONE"
                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
            ):
                out_cache_loc = out_cache_loc.contiguous()
            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.return_hidden_states = False
        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info

        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

        if batch.has_grammar:
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
                # and will be applied to produce wrong results
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (i.e., the accepted tokens)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = (
            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
        )
        batch.spec_info = res.draft_input

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        accepted_indices = res.accepted_indices
        assert len(accepted_indices) == len(logits_output.next_token_logits)

        temperatures = batch.sampling_info.temperatures
        num_draft_tokens = batch.spec_info.draft_token_num
        # The accepted indices are indices into the "flattened" batch;
        # dividing by num_draft_tokens yields the actual batch index.
        temperatures = temperatures[accepted_indices // num_draft_tokens]
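        # For example (illustrative values), with num_draft_tokens=4 and
        # accepted_indices=[0, 1, 5], the integer division gives batch indices
        # [0, 0, 1], so each accepted token picks up its own request's temperature.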
        if RETURN_ORIGINAL_LOGPROB:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits, dim=-1
            )
        else:
            logprobs = torch.nn.functional.log_softmax(
                logits_output.next_token_logits / temperatures, dim=-1
            )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
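        # For example (illustrative values), with num_tokens_per_req=[3, 1] and
        # top_logprobs_nums=[2, 0], the repeated list becomes [2, 2, 2, 0]:
        # one entry per accepted token of each request.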

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(
                logprobs,
                top_logprobs_nums_repeat_interleaved,
            )

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(
                logprobs,
                token_ids_logprobs_repeat_interleaved,
            )

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
            num_tokens_per_batch=1,
            num_tokens_for_logprob_per_batch=1,
        )
        batch.return_hidden_states = False
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=seq_lens_cpu
        )
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)
        has_finished, unfinished_req_index = False, []
        for i, req in enumerate(batch.reqs):
            if req.finished():
                has_finished = True
            else:
                unfinished_req_index.append(i)
        if has_finished:
            unfinished_index_device = torch.tensor(
                unfinished_req_index,
                dtype=torch.int64,
                device=batch.spec_info.topk_p.device,
            )
            batch.spec_info.filter_batch(
                unfinished_index_device, has_been_filtered=False
            )

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        assert isinstance(batch.spec_info, EagleDraftInput)
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        input_is_idle = batch.forward_mode.is_idle()

        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
            batch = batch.copy()
            batch.prepare_for_idle()
            hidden_size = (
                self.model_config.hidden_size * 3
                if self.speculative_algorithm.is_eagle3()
                else self.model_config.hidden_size
            )
            batch.spec_info = EagleDraftInput.create_idle_input(
                device=self.device,
                hidden_size=hidden_size,
                dtype=self.model_config.dtype,
                topk=self.topk,
                capture_hidden_mode=CaptureHiddenMode.LAST,
            )

        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
        batch.spec_info.num_tokens_for_logprob_per_batch = 1
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.forward_mode = (
            ForwardMode.DRAFT_EXTEND
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )

        batch.return_hidden_states = False
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        if forward_batch.seq_lens_cpu is not None:
            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
        else:
            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

        # Run
        can_cuda_graph = (
            self.cuda_graph_runner_for_draft_extend
            and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
        )
        if can_cuda_graph:
            logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                forward_batch
            )
            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
                logits_output.topk_p,
                logits_output.topk_index,
            )
            forward_batch.spec_info.hidden_states = logits_output.hidden_states
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                self.draft_model_runner.attn_backend.init_forward_metadata(
                    forward_batch
                )
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self.capture_for_decode(logits_output, forward_batch.spec_info)

        self._detect_nan_if_needed(logits_output)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = (
            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
        )
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int64)


@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens,
    speculative_num_steps: int,
):
    prefix_lens = seq_lens
    seq_lens = prefix_lens + speculative_num_steps
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )
    return prefix_lens, seq_lens, last_loc


# Disable torch.compile for this function because it will be
# even slower.
# @torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
    prefix_lens = seq_lens
    last_page_lens = prefix_lens % page_size
    num_new_pages_per_topk = (
        last_page_lens + speculative_num_steps + page_size - 1
    ) // page_size
    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
        page_size * topk
    )
    extend_lens = seq_lens - prefix_lens
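    # Worked example (illustrative values): prefix_len=5, page_size=4,
    # speculative_num_steps=3, topk=2 gives last_page_len=1 and
    # num_new_pages_per_topk=1, so seq_len = 4 + 1 * (4 * 2) = 12 and
    # extend_len = 12 - 5 = 7 slots are allocated for this request.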
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )

    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens