import logging
import os
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
    ScheduleBatch,
    get_last_loc,
    global_server_args_dict,
)
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
    EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)


@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
        yield


class EAGLEWorker(TpModelWorker):

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.padded_static_len = self.speculative_num_steps + 1
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )

        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture CUDA graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with the target worker.
        # Draft and target worker own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # EAGLE3 models don't share lm_head
            self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
                embed.device
            )
        else:
            if self.hot_token_id is not None:
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and CUDA graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

    def init_attention_backend(self):
        # Create multi-step attn backends and CUDA graph runners
        if self.server_args.attention_backend == "flashinfer":
            if not global_server_args_dict["use_mla_backend"]:
                from sglang.srt.layers.attention.flashinfer_backend import (
                    FlashInferMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            else:
                from sglang.srt.layers.attention.flashinfer_mla_backend import (
                    FlashInferMLAMultiStepDraftBackend,
                )

                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
                    self.draft_model_runner,
                    self.topk,
                    self.speculative_num_steps,
                )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = True
        elif self.server_args.attention_backend == "triton":
            from sglang.srt.layers.attention.triton_backend import (
                TritonMultiStepDraftBackend,
            )

            self.draft_attn_backend = TritonMultiStepDraftBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        elif self.server_args.attention_backend == "fa3":
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionMultiStepBackend,
            )

            self.draft_attn_backend = FlashAttentionMultiStepBackend(
                self.draft_model_runner,
                self.topk,
                self.speculative_num_steps,
            )
            self.draft_extend_attn_backend = None
            self.padded_static_len = self.speculative_num_steps + 1
            self.has_prefill_wrapper_verify = False
        else:
            raise ValueError(
                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
            )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture CUDA graphs."""
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        tic = time.time()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft CUDA graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture draft CUDA graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            raise NotImplementedError()

    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int, int, bool]:
        """Run speculative decoding forward.

        NOTE: Many states of the batch are modified in place as this method runs.
        It is not guaranteed that the final output batch has the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logits output of the target model, the accepted next tokens,
            the batch id (used for overlap schedule), the number of accepted tokens, and
            whether the CUDA graph path was used.
        """
        if batch.forward_mode.is_decode():
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            # If it is None, it means all requests are finished
            if batch.spec_info.verified_id is not None:
                with self.draft_tp_context(self.draft_model_runner.tp_group):
                    self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph,
            )
        elif batch.forward_mode.is_idle():
            model_worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids, _ = (
                self.target_worker.forward_batch_generation(model_worker_batch)
            )

            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
        else:
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids
                )
            return logits_output, next_token_ids, bid, 0, False

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid

    def draft(self, batch: ScheduleBatch):
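        """Run the multi-step draft forward and build the speculative token tree.

        Allocates KV cache slots for the draft tokens, runs the draft model for
        `speculative_num_steps` steps (replaying the draft CUDA graph when
        possible), restores the allocator state, and packs the per-step scores,
        tokens, and parent indices into an `EagleVerifyInput` for verification.
        """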
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
            )
        else:
            if self.topk == 1:
                prefix_lens = batch.seq_lens
                seq_lens = prefix_lens + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO: fuse these ops
                prefix_lens = batch.seq_lens
                last_page_lens = prefix_lens % self.page_size
                num_new_pages = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens = (
                    prefix_lens // self.page_size * self.page_size
                    + num_new_pages * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
                raise NotImplementedError(
                    "page_size > 1 and top_k > 1 are not supported."
                )
                # TODO: Support page_size > 1 and top_k > 1
                # 1. Duplicate the KV cache in the last partial page for all top-k segments
                # 2. Modify generate_draft_decode_kv_indices accordingly

            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            out_cache_loc, token_to_kv_pool_state_backup = (
                batch.alloc_paged_token_slots_extend(
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
        )
        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)

        # Get forward batch
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            # Initialize attention backend
            self.draft_attn_backend.init_forward_metadata(forward_batch)
            forward_batch = ForwardBatch.init_new(
                model_worker_batch, self.draft_model_runner
            )
            # Run forward steps
            score_list, token_list, parents_list = self.draft_forward(forward_batch)

        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

        ret = EagleVerifyInput.create(
            spec_info.verified_id,
            score_list,
            token_list,
            parents_list,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.server_args.speculative_num_draft_tokens,
        )
        return ret

    def draft_forward(self, forward_batch: ForwardBatch):
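        """Run the draft model step by step to expand the top-k token tree.

        For each step, the top-k candidates are selected, the per-step cache
        locations and attention backend are set on `forward_batch`, and the
        draft model is run again. Returns the per-step score, token, and
        parent lists consumed by `EagleVerifyInput.create`.
        """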
        # Parse args
        spec_info = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward step: we get 1 token from the draft prefill and (#spec steps - 1) tokens here.
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
            forward_batch.out_cache_loc = out_cache_loc[
                :, self.topk * i : self.topk * (i + 1)
            ].flatten()
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.draft_model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        return score_list, token_list, parents_list

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
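        """Verify the draft token tree with the target model.

        Runs the target model in TARGET_VERIFY mode, lets `spec_info.verify`
        select the accepted tokens, trims the logits and hidden states to the
        accepted indices, and switches the batch back to DECODE mode with the
        draft input for the next round.
        """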
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = spec_info
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, _, can_run_cuda_graph = (
            self.target_worker.forward_batch_generation(
                model_worker_batch, skip_sample=True
            )
        )
        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (the accepted tokens)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = ForwardMode.DECODE
        batch.spec_info = res.draft_input

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
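        """Compute logprobs for the accepted tokens and attach them to the
        requests in the batch, including optional top-k and token-id logprobs."""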
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits, dim=-1
        )
        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # We should repeat top_logprobs_nums to match num_tokens_per_req.
        top_logprobs_nums_repeat_interleaved = []
        token_ids_logprobs_repeat_interleaved = []
        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)

        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        pt = 0
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
            for _ in range(num_tokens):
                if req.return_logprob:
                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
                    req.output_token_logprobs_idx.append(verified_ids[pt])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            res.logits_output.next_token_top_logprobs_val[pt]
                        )
                        req.output_top_logprobs_idx.append(
                            res.logits_output.next_token_top_logprobs_idx[pt]
                        )
                pt += 1

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: List[int],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
        )
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
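        """Run a draft-model extend pass over the newly accepted tokens.

        This refreshes the draft model's KV cache and captures the hidden states
        and top-k proposals needed for the next draft round. Fields that are
        modified in place are backed up first and restored at the end.
        """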
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        # Prepare metadata
        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        batch.return_logprob = False
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )

        # Run
        logits_output, _ = self.draft_model_runner.forward(forward_batch)

        self._detect_nan_if_needed(logits_output)
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> torch.Tensor:
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int32)