import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.padded_static_len = self.speculative_num_steps + 1
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # The draft and target workers each own their own KV cache pool.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # EAGLE3 models don't share lm_head
            self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
                embed.device
            )
        else:
            if self.hot_token_id is not None:
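                # Slice the lm_head down to the hot vocabulary so the draft model
                # only scores hot tokens; `draft_forward` maps the resulting
                # indices back through `self.hot_token_id`.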
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

    def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
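        # Each multi-step draft backend below wraps one attention backend per
        # speculative step; `draft_forward` switches `forward_batch.attn_backend`
        # to the i-th backend on step i.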
        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "flashmla":
            from sglang.srt.layers.attention.flashmla_backend import (
                FlashMLAMultiStepDraftBackend,
            )

            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            raise NotImplementedError()

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place as this runs. The final
        output batch is not guaranteed to have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted
            next token ids, the batch id (used for the overlap schedule), the number
            of accepted tokens, and whether the batch can run with a cuda graph.
        """
        if batch.forward_mode.is_decode():
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            # If verified_id is None, all requests are finished
            if batch.spec_info.verified_id is not None:
                with self.draft_tp_context(self.draft_model_runner.tp_group):
                    self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph,
            )
        elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids, _ = (
                self.target_worker.forward_batch_generation(model_worker_batch)
            )

            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
        else:
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids
                )
            return logits_output, next_token_ids, bid, 0, False

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid

    def draft(self, batch: ScheduleBatch):
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens = batch.seq_lens
                seq_lens = prefix_lens + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO: fuse these ops
                prefix_lens = batch.seq_lens
                last_page_lens = prefix_lens % self.page_size
                num_new_pages = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens = (
                    prefix_lens // self.page_size * self.page_size
                    + num_new_pages * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
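                # Worked example with hypothetical numbers: page_size=4, topk=2,
                # speculative_num_steps=3, and a prefix of 10 tokens:
                #   last_page_len  = 10 % 4 = 2
                #   num_new_pages  = (2 + 3 + 4 - 1) // 4 = 2
                #   seq_len        = 10 // 4 * 4 + 2 * (4 * 2) = 24
                #   extend tokens  = 24 - 10 = 14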
                raise NotImplementedError(
                    "page_size > 1 and top_k > 1 are not supported."
                )
                # TODO: Support page_size > 1 and top_k > 1
                # 1. Duplicate the KV cache in the last partial page for all top-k segments
                # 2. Modify generate_draft_decode_kv_indices accordingly

            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

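        # Triton kernel launch with one program per sequence (grid = (num_seqs,)).
        # It records the freshly allocated `out_cache_loc` slots in the
        # req_to_token map so that the topk * num_steps draft positions of each
        # request can be resolved during the draft forward passes.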
        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
        )
        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)

        # Get forward batch
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            # Initialize attention backend
            self.draft_attn_backend.init_forward_metadata(forward_batch)
            forward_batch = ForwardBatch.init_new(
                model_worker_batch, self.draft_model_runner
            )
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

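        # The KV cache slots were only needed while drafting: restore the
        # allocator to its pre-draft state so those slots are released; the
        # verification step handles its own allocation.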
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

        ret = EagleVerifyInput.create(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.server_args.speculative_num_draft_tokens,
        )
        return ret

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
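        # Each iteration grows the draft token tree by one level:
        # select_top_k_tokens picks the next frontier of candidate tokens and
        # returns (scores, tokens, parents) for that level; the accumulated
        # lists are later packed into an EagleVerifyInput for verification.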
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward pass: we already get 1 token from
            # the draft prefill and (#spec steps - 1) tokens here
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
            forward_batch.out_cache_loc = out_cache_loc[
                :, self.topk * i : self.topk * (i + 1)
            ].flatten()
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.draft_model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
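        """Verify the draft tokens with the target model.

        Runs one TARGET_VERIFY forward of the target model over all draft tokens,
        lets `spec_info.verify` decide which tokens are accepted, keeps only the
        accepted logits and hidden states, and switches the batch back to DECODE
        mode with the new draft input for the next round.
        """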
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )
        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (the accepted tokens)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = ForwardMode.DECODE
        batch.spec_info = res.draft_input

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits, dim=-1
        )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: List[int],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
        )
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        # Prepare metadata
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        batch.return_logprob = False
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )

        # Run
        logits_output, _ = self.draft_model_runner.forward(forward_batch)

        self._detect_nan_if_needed(logits_output)
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
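        """Stash the draft inputs for the next decode round: the top-k token
        probabilities and indices plus the hidden states from the draft extend."""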
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
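    """Load the hot token id map, downloading it from the Hugging Face Hub when
    the given path does not exist locally."""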
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int32)