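"""EAGLE speculative decoding worker.

EAGLEWorker wraps a draft model next to the target TpModelWorker: it runs
multi-step draft forwards, builds a draft token tree, and verifies the drafted
tokens with the target model (see `forward_batch_speculative_generation`).
"""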
import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import (
    GroupCoordinator,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
    EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    fast_topk,
    generate_token_bitmask,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    is_cuda,
    next_power_of_2,
)

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # The draft and target workers each own their own KV cache pool.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # EAGLE3 models don't share lm_head
            self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            if self.draft_model_runner.model.hot_token_id is not None:
                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
                    embed.device
                )

        else:
            if self.hot_token_id is not None:
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners

        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferAttnBackend,
                    FlashInferMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
                self.draft_extend_attn_backend = FlashInferAttnBackend(
                    self.draft_model_runner,
                    skip_prefill=False,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAAttnBackend,
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
                    self.draft_model_runner,
                    skip_prefill=False,
                )
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonAttnBackend,
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = TritonAttnBackend(
                self.draft_model_runner,
                skip_prefill=False,
            )
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = FlashAttentionBackend(
                self.draft_model_runner,
                skip_prefill=False,
            )
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import (
                FlashMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            tic = time.perf_counter()
            before_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
            )
            self.cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner(
                self
            )
            after_mem = get_available_gpu_memory(self.device, self.gpu_id)
            logger.info(
                f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
            )

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place as the forward runs.
        It is not guaranteed that the final output batch has the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted next
            tokens, the batch id (used for the overlap schedule), the number of accepted
            tokens, and whether the forward ran with a cuda graph.
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            logits_output, next_token_ids, bid, seq_lens_cpu = (
                self.forward_target_extend(batch)
            )
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                )
            return logits_output, next_token_ids, bid, 0, False
        else:
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            with self.draft_tp_context(self.draft_model_runner.tp_group):
                # NOTE: We should use `check_forward_draft_extend_after_decode`
                # when DP attention is enabled, but it is slow. Skip it for now.
                if (
                    self.server_args.enable_dp_attention
                    or batch.spec_info.verified_id.shape[0] > 0
                ):
                    # decode is not finished
                    self.forward_draft_extend_after_decode(batch)

            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph,
Lianmin Zheng's avatar
Lianmin Zheng committed
346
            )

    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
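        """Return whether a draft-extend forward is still needed after decode.

        Locally this checks whether any verified tokens remain; with DP attention
        enabled, the decision is all-reduced across the TP group so every rank agrees.
        """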
        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
        if not self.server_args.enable_dp_attention:
            return local_need_forward

        global_need_forward = torch.tensor(
            [
                (local_need_forward),
            ],
            dtype=torch.int64,
        )
        torch.distributed.all_reduce(
            global_need_forward, group=get_tp_group().cpu_group
        )
        global_need_forward_cnt = global_need_forward[0].item()
        need_forward = global_need_forward_cnt > 0
        return need_forward

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The logits output. It contains the full hidden states.
            next_token_ids: Next token ids generated by the target model.
            bid: The model batch ID. Used for the overlap schedule.
            seq_lens_cpu: The sequence lengths on CPU, cached for later reuse.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return (
            logits_output,
            next_token_ids,
            model_worker_batch.bid,
            model_worker_batch.seq_lens_cpu,
        )

    def _draft_preprocess_decode(self, batch: ScheduleBatch):
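        """Prepare a decode batch for the multi-step draft forward.

        Allocates KV cache slots for all `speculative_num_steps * topk` draft
        positions of every request, writes them into the req-to-token table via
        `assign_draft_cache_locs`, and then restores the allocator state so the
        draft slots are not permanently consumed.
        """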
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        # Layout of the out_cache_loc
        # [       topk 0         ] [       topk 1         ]
        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                )
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO(lmzheng): The current implementation is still a fake support
                # for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real kv cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating KV
                # cache seems okay because the draft KV cache only has one layer.
                # see a related copy operation in MHATokenToKVPool::move_kv_cache.

                (
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    self.num_new_pages_per_topk,
                    self.extend_lens,
                ) = get_last_loc_large_page_size_large_top_k(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                    self.topk,
                    self.page_size,
                )

                # TODO(lmzheng): remove this device sync
                extend_num_tokens = torch.sum(self.extend_lens).item()

            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            self.extend_lens,
            self.num_new_pages_per_topk,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
        )

        if self.page_size > 1 and self.topk > 1:
            # Remove padded slots
            out_cache_loc = out_cache_loc[
                : num_seqs * self.topk * self.speculative_num_steps
            ]

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        batch.return_hidden_states = False
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

    def _draft_preprocess_idle(self, batch: ScheduleBatch):
        batch.spec_info = EagleDraftInput.create_idle_input(
            device=self.device,
            hidden_size=self.model_config.hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

    def draft(self, batch: ScheduleBatch):
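        """Run the draft model to propose tokens for one speculative round.

        Returns an `EagleVerifyInput` carrying the drafted tokens, the tree
        attention mask, and the retrieval indices consumed later by `verify`.
        """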
        # Parse args
        if batch.forward_mode.is_idle():
            self._draft_preprocess_idle(batch)
        else:
            self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)

        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        spec_info.num_tokens_per_batch = self.topk
        spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        if batch.forward_mode.is_idle():
            return EagleVerifyInput.create_idle_input(
                self.topk,
                self.speculative_num_steps,
                self.speculative_num_draft_tokens,
            )

        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.server_args.speculative_num_draft_tokens,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=forward_batch.seq_lens_sum,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
        )

    def draft_forward(self, forward_batch: ForwardBatch):
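        """Run the multi-step draft forward for a prepared forward batch.

        Each step selects the top-k candidate tokens from the previous step's
        distribution, runs one draft-model forward, and records the scores,
        tokens, and parent indices used to build the draft tree.
        """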
        # Parse args
        spec_info = forward_batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward step: we get 1 token from the draft prefill and (#spec steps - 1) tokens from the steps here.
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
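        """Verify the drafted token tree with the target model.

        Runs a TARGET_VERIFY forward on the target model, optionally applies a
        grammar bitmask, and lets `spec_info.verify` decide which draft tokens
        are accepted. Returns the logits restricted to the accepted positions,
        the verify output, the model worker batch, and whether a cuda graph ran.
        """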
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.return_hidden_states = False
        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info

        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

        if batch.has_grammar:
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
                # and will be applied to produce wrong results
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (the accepted tokens)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = (
            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
        )
        batch.spec_info = res.draft_input

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
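        """Attach logprob information for the accepted tokens to the requests.

        Recomputes temperature-scaled log-probabilities over the accepted
        positions and fills the per-request output/top/token-id logprob fields.
        """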
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        accepted_indices = res.accepted_indices
        assert len(accepted_indices) == len(logits_output.next_token_logits)
        temperatures = batch.sampling_info.temperatures
        num_draft_tokens = batch.spec_info.draft_token_num
        # Accepted indices are indices into the "flattened" batch;
        # dividing them by num_draft_tokens yields the actual batch index.
        temperatures = temperatures[accepted_indices // num_draft_tokens]

        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits / temperatures, dim=-1
        )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward.
            next_token_ids: Next token ids generated from the target forward.
            seq_lens_cpu: Cached sequence lengths on CPU from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
            num_tokens_per_batch=1,
            num_tokens_for_logprob_per_batch=1,
        )
        batch.return_hidden_states = False
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=seq_lens_cpu
        )
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
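        """Run a draft-extend forward over the accepted tokens after decode.

        This refreshes the draft model's KV cache and top-k state for the next
        speculative round. Batch fields modified in place are backed up at the
        start and restored at the end.
        """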
        assert isinstance(batch.spec_info, EagleDraftInput)
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        input_is_idle = batch.forward_mode.is_idle()

        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
            batch = batch.copy()
            batch.prepare_for_idle()
            hidden_size = (
                self.model_config.hidden_size * 3
                if self.speculative_algorithm.is_eagle3()
                else self.model_config.hidden_size
            )
            batch.spec_info = EagleDraftInput.create_idle_input(
                device=self.device,
                hidden_size=hidden_size,
                dtype=self.model_config.dtype,
                topk=self.topk,
                capture_hidden_mode=CaptureHiddenMode.LAST,
            )

        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
        batch.spec_info.num_tokens_for_logprob_per_batch = 1
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.forward_mode = (
            ForwardMode.DRAFT_EXTEND
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )

        batch.return_hidden_states = False
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        if forward_batch.seq_lens_cpu is not None:
            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
        else:
            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

        # Run
        can_cuda_graph = (
            self.cuda_graph_runner_for_draft_extend
            and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
        )
        if can_cuda_graph:
            logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                forward_batch
            )
            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
                logits_output.topk_p,
                logits_output.topk_index,
            )
            forward_batch.spec_info.hidden_states = logits_output.hidden_states
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                self.draft_model_runner.attn_backend.init_forward_metadata(
                    forward_batch
                )
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self.capture_for_decode(logits_output, forward_batch.spec_info)

        self._detect_nan_if_needed(logits_output)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = (
            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
        )
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
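        """Store the draft model's top-k distribution and hidden states on the
        draft input for the next decode round."""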
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
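    """Load a hot-token id map, downloading it from the Hugging Face Hub if the
    local path does not exist."""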
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int64)


@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens,
    speculative_num_steps: int,
):
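    """Compute prefix/extended lengths and last cache locations for top-k == 1.

    Each request simply grows by `speculative_num_steps` tokens.
    """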
    prefix_lens = seq_lens
    seq_lens = prefix_lens + speculative_num_steps
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )
    return prefix_lens, seq_lens, last_loc


@torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
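    """Compute page-aligned lengths when top-k > 1 and page size > 1.

    Each of the `topk` draft branches gets its own copy of the request's last
    partial page, so the extended length is rounded down to a page boundary and
    then grown by `num_new_pages_per_topk` pages per branch.
    """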
    prefix_lens = seq_lens
    last_page_lens = prefix_lens % page_size
    num_new_pages_per_topk = (
        last_page_lens + speculative_num_steps + page_size - 1
    ) // page_size
    seq_lens = prefix_lens // page_size * page_size + num_new_pages_per_topk * (
        page_size * topk
    )
    extend_lens = seq_lens - prefix_lens
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )

    return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens