# eagle_worker.py
import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
    empty_context,
    fast_topk,
    get_available_gpu_memory,
    is_cuda_available,
)

if is_cuda_available():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.padded_static_len = self.speculative_num_steps + 1
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # The draft and target workers each own their own KV cache pool.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None
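        # NOTE: when a token map is used, the draft model scores only a "hot"
        # subset of the vocabulary (lm_head is sliced to hot_vocab_size rows
        # below), and draft top-k indices are mapped back to full-vocabulary
        # ids in draft_forward via self.hot_token_id.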

        # Init draft worker
        with empty_context():
            super().__init__(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=nccl_port,
                dp_rank=dp_rank,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # EAGLE3 models don't share lm_head
            self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
                embed.device
            )
        else:
            if self.hot_token_id is not None:
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.time()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            raise NotImplementedError()

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place as this function runs.
        The final output batch is not guaranteed to have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted next tokens,
            the batch id (used for the overlap schedule), and the number of accepted tokens.
        """
        if batch.forward_mode.is_decode():
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch = self.verify(
                batch, spec_info
            )

            # If verified_id is None, it means all requests are finished
            if batch.spec_info.verified_id is not None:
                with self.draft_tp_context(self.draft_model_runner.tp_group):
                    self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
            )
        elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                ForwardBatch.init_new(
                    model_worker_batch, self.target_worker.model_runner
                )
            )
            return logits_output, next_token_ids, model_worker_batch.bid, 0
        else:
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids
                )
            return logits_output, next_token_ids, bid, 0

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid

    def draft(self, batch: ScheduleBatch):
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
            )
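            # Illustrative count: one slot per (sequence, top-k branch, draft step),
            # e.g. num_seqs=2, topk=4, speculative_num_steps=3 -> 2 * 4 * 3 = 24 slots.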
        else:
            if self.topk == 1:
                prefix_lens = batch.seq_lens
                seq_lens = prefix_lens + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO: fuse these ops
                prefix_lens = batch.seq_lens
                last_page_lens = prefix_lens % self.page_size
                num_new_pages = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens = (
                    prefix_lens // self.page_size * self.page_size
                    + num_new_pages * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
                raise NotImplementedError(
                    "page_size > 1 and top_k > 1 are not supported."
                )
                # TODO: Support page_size > 1 and top_k > 1
                # 1. Duplicate the KV cache in the last partial page for all top-k segments
                # 2. Modify generate_draft_decode_kv_indices accordingly
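                # Illustrative numbers for the math above (assuming page_size=4,
                # prefix_len=10, speculative_num_steps=3, topk=2):
                # last_page_len = 10 % 4 = 2, num_new_pages = (2 + 3 + 4 - 1) // 4 = 2,
                # seq_lens = 8 + 2 * (4 * 2) = 24, i.e. each top-k branch would get
                # its own copy of the last partial page plus room for its drafts.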

            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
        )
        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
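        # Every one of the `topk` root drafts of a sequence starts at the same
        # position (its current sequence length), hence the repeat_interleave below.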
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)

        # Get forward batch
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            # Initialize attention backend
            self.draft_attn_backend.init_forward_metadata(forward_batch)
            forward_batch = ForwardBatch.init_new(
                model_worker_batch, self.draft_model_runner
            )
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
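        # Restoring the allocator state releases the draft-tree KV slots allocated
        # above; they are only needed while scoring the tree, and verification
        # allocates its own slots later (see prepare_for_verify).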

        ret = EagleVerifyInput.create(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.server_args.speculative_num_draft_tokens,
        )
        return ret

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
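        # Each iteration grows the draft token tree by one level: select_top_k_tokens
        # keeps a frontier of `topk` nodes per sequence (scoring topk * topk
        # candidates after the first step) and records the scores / tokens / parent
        # links that EagleVerifyInput.create consumes after the loop.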
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward: we get 1 token from the draft
            # prefill and (speculative_num_steps - 1) tokens here.
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
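            # out_cache_loc was allocated as [bs, topk * speculative_num_steps];
            # step i writes its draft KV cache into the i-th block of `topk` columns.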
            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
            forward_batch.out_cache_loc = out_cache_loc[
                :, self.topk * i : self.topk * (i + 1)
            ].flatten()
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.draft_model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _ = self.target_worker.forward_batch_generation(
            model_worker_batch, skip_sample=True
        )
        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
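        # In short, spec_info.verify compares the draft tree against the target
        # logits: along each path, tokens are accepted while they agree with the
        # target, and one bonus token from the target follows the last accepted one.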
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (i.e., the accepted tokens).
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepeted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = ForwardMode.DECODE
        batch.spec_info = res.draft_input

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        return logits_output, res, model_worker_batch

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits, dim=-1
        )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
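        # Example: a request that accepted 2 draft tokens yields num_tokens = 3
        # output positions (the accepted tokens plus the bonus token), so its
        # per-request logprob settings are repeated 3 times, once per position.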

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: List[int],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
        )
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        # Back up fields that will be modified in place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob
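        # These fields are mutated by the draft-extend pass below, but the scheduler
        # still expects the original decode-mode values, so they are restored at the
        # end of this function.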

        # Prepare metadata
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        batch.return_logprob = False
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )

        # Run
        logits_output = self.draft_model_runner.forward(forward_batch)

        self._detect_nan_if_needed(logits_output)
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Restore the backup.
        # This is needed because `seq_lens` can be modified in `prepare_extend_after_decode`.
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int32)