# eagle_worker.py
import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.padded_static_len = self.speculative_num_steps + 1
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the req-to-token pool and KV cache allocator with the target worker.
        # The draft and target workers still own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # EAGLE3 models don't share lm_head
            self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
                embed.device
            )
        else:
            if self.hot_token_id is not None:
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)
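            # Note (added for clarity; the sizes below are hypothetical): with a hot token
            # map, e.g. 32k hot tokens out of a 128k-token vocabulary, the draft lm_head
            # above is sliced down to the hot rows, so the draft model's top-k indices must
            # be mapped back to full-vocabulary ids via `hot_token_id` (see `draft_forward`).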

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.time()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            raise NotImplementedError()

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place as this function runs. The final
        output batch is not guaranteed to have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted next tokens,
            the batch id (used for the overlap schedule), and the number of accepted tokens.
        """
        if batch.forward_mode.is_decode():
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch = self.verify(
                batch, spec_info
            )

            # If verified_id is None, it means all requests are finished
            if batch.spec_info.verified_id is not None:
                with self.draft_tp_context(self.draft_model_runner.tp_group):
                    self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
            )
        elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                model_worker_batch
            )

            return logits_output, next_token_ids, model_worker_batch.bid, 0
        else:
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids
                )
            return logits_output, next_token_ids, bid, 0
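
    # Illustrative usage sketch (not part of the original file; `eagle_worker` and `batch`
    # are assumed to exist). A scheduler consumes the return value roughly as:
    #
    #   logits_output, next_token_ids, bid, num_accepted = (
    #       eagle_worker.forward_batch_speculative_generation(batch)
    #   )
    #
    # where `next_token_ids` are the accepted tokens, `bid` is the batch id used by the
    # overlap schedule, and `num_accepted` is the number of accepted tokens summed over
    # all requests in the batch.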

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid

    def draft(self, batch: ScheduleBatch):
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens = batch.seq_lens
                seq_lens = prefix_lens + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO: fuse these ops
                prefix_lens = batch.seq_lens
                last_page_lens = prefix_lens % self.page_size
                num_new_pages = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens = (
                    prefix_lens // self.page_size * self.page_size
                    + num_new_pages * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
                raise NotImplementedError(
                    "page_size > 1 and top_k > 1 are not supported."
                )
                # TODO: Support page_size > 1 and top_k > 1
                # 1. Duplicate the KV cache in the last partial page for all top-k segments
                # 2. Modify generate_draft_decode_kv_indices accordingly
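
                # Illustrative arithmetic for the formulas above (hypothetical values):
                # page_size = 16, topk = 2, speculative_num_steps = 5, prefix_len = 35:
                #   last_page_len = 35 % 16                      = 3
                #   num_new_pages = (3 + 5 + 16 - 1) // 16       = 1
                #   seq_len       = 35 // 16 * 16 + 1 * (16 * 2) = 64
                #   extend tokens = 64 - 35                      = 29
                # i.e., each of the 2 top-k branches gets one fresh page, and the 3 tokens in
                # the last partial prefix page would be duplicated into each branch.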

            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
        )
        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
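        # (Illustration with hypothetical values: seq_lens = [7, 9] and topk = 2 give
        #  positions = [7, 7, 9, 9], i.e., one entry per top-k branch of each request.)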

        # Get forward batch
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            # Initialize attention backend
            self.draft_attn_backend.init_forward_metadata(forward_batch)
            forward_batch = ForwardBatch.init_new(
                model_worker_batch, self.draft_model_runner
            )
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

        ret = EagleVerifyInput.create(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.server_args.speculative_num_draft_tokens,
        )
        return ret

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward step: we get 1 token from the draft
            # prefill and (#spec steps - 1) tokens from the forward steps in this loop.
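            # (For example, with speculative_num_steps = 3 the forward below runs only for
            #  i = 0 and i = 1; at i = 2 we record the tree info and break, since that
            #  forward's output would never be consumed.)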
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
            forward_batch.out_cache_loc = out_cache_loc[
                :, self.topk * i : self.topk * (i + 1)
            ].flatten()
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.draft_model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _ = self.target_worker.forward_batch_generation(
            model_worker_batch, skip_sample=True
        )
        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
        )

        # Post process based on the verified outputs.
        # Pick the indices that we care about (i.e., the accepted tokens).
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepeted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = ForwardMode.DECODE
        batch.spec_info = res.draft_input

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        return logits_output, res, model_worker_batch

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits, dim=-1
        )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
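        # (Illustration with hypothetical values: top_logprobs_nums = [5, 0] and
        #  num_tokens_per_req = [3, 2] expand to [5, 5, 5, 0, 0].)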
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: List[int],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
        )
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        # Prepare metadata
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        batch.return_logprob = False
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )

        # Run
        logits_output = self.draft_model_runner.forward(forward_batch)

        self._detect_nan_if_needed(logits_output)
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Restore the backup.
        # This is needed because `seq_lens` can be modified in `prepare_extend_after_decode`.
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int32)
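

# Illustrative usage of `load_token_map` (not part of the original module; the paths and
# repo name below are hypothetical):
#
#   hot_token_id = load_token_map("/data/hot_token_map.pt")
#       # loads a local file saved with `torch.save(list_of_token_ids, path)`
#   hot_token_id = load_token_map("some-org/some-repo/freq_32768.pt")
#       # the file is fetched from the Hugging Face Hub ("some-org/some-repo") via
#       # `snapshot_download` and then loaded from the local cache
#
# Either way, the result is an int32 tensor of hot token ids, which `EAGLEWorker` uses to
# shrink the draft lm_head and to map draft top-k indices back to the full vocabulary.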