eagle.py 52.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.config import (
12
    CompilationMode,
13
14
15
16
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.platforms import current_platform
27
from vllm.utils.platform_utils import is_pin_memory_available
28
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
29
30
31
32
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
33
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
34
35
36
37
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
38
from vllm.v1.kv_cache_interface import KVCacheConfig
39
from vllm.v1.sample.metadata import SamplingMetadata
40
from vllm.v1.sample.sampler import _SAMPLING_EPS
41
42
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
44

45
46
logger = init_logger(__name__)

# Slot index written into the draft ``slot_mapping`` for tokens whose position
# exceeds ``max_model_len`` during multi-step drafting, so that those padding
# tokens do not write into the KV cache.
PADDING_SLOT_ID = -1


class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read drafting config and allocate persistent device buffers.

        Args:
            vllm_config: Engine config; its ``speculative_config`` must be set.
            device: Device on which all persistent draft buffers are allocated.
            runner: Owning model runner. ``propose`` asserts it is not None and
                ``propose_tree`` reads ``runner.attn_groups``.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            # NOTE(review): `self.positions` only exists on this branch, but
            # `propose_tree` reads `self.positions.device` unconditionally, so
            # tree drafting with M-RoPE would raise AttributeError — confirm.
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Host/device staging buffer for the per-request fallback next token
        # used by `prepare_next_token_ids_padded`.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        from vllm.attention.backends.registry import AttentionBackendEnum

        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is taken from the last node, so this assumes the tree
        # choices are ordered by depth.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

    def _get_positions(self, num_tokens: int):
        # View of the first `num_tokens` positions in the active buffer: the
        # flat 1-D buffer for RoPE, or every row of the 3-row M-RoPE buffer.
        if not self.uses_mrope:
            return self.positions[:num_tokens]
        return self.mrope_positions[:, :num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        """Copy `positions` into the first `num_tokens` slots of the active buffer.

        With M-RoPE the copy targets every row of ``mrope_positions``;
        otherwise it targets the flat ``positions`` buffer.
        """
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Generate draft tokens for every request in the batch.

        The first draft step runs the draft model over all target tokens,
        shifted left by one, with each request's last slot replaced by its
        next token. Each further step feeds back the previous step's argmax
        token, one token per request.

        Args:
            target_token_ids: Flattened target-model input ids.
            target_positions: Positions matching ``target_token_ids``.
            target_hidden_states: Target-model hidden states (combined via
                ``combine_hidden_states`` first when the method is eagle3).
            next_token_ids: Next token of each request.
            last_token_indices: Index of each request's last token in the
                flattened token dimension. Defaults to
                ``query_start_loc[1:] - 1`` when None.
            common_attn_metadata: Attention metadata for the target tokens.
                It is mutated in place when more than one draft token is
                generated.
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional ``(embeddings, is_multimodal)`` pair,
                forwarded to ``embed_input_ids`` on the first step.

        Returns:
            ``[batch_size, num_speculative_tokens]`` draft token ids, or
            ``[batch_size, num_tree_tokens]`` when tree attention is used.

        Raises:
            ValueError: If more than one draft token is requested and the
                attention metadata type is not in ``allowed_attn_types``.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Advanced indexing copies, so the in-place `positions += 1` below
        # does not touch `target_positions`.
        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if (
            self.use_cuda_graph
            and batch_size <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size
            cudagraph_runtime_mode = CUDAGraphMode.NONE

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length.
            # Since it is complex to remove such requests from the batch,
            # we keep them in the batch but adjust the position ids
            # and slot mappings to avoid the
            # out-of-range access during the model execution.
            # The draft tokens generated with this adjustment
            # should be ignored.
            positions += 1
            if self.uses_mrope:
                # All M-RoPE rows are identical for text, so row 0 decides.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the original tensor.
            common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
            # Keep the max sequence length in step with `seq_lens`; otherwise
            # the metadata rebuilt below under-reports it by one per step.
            # It is only an upper bound, so increment it unconditionally, but
            # cap it at max_model_len: requests past that limit are masked to
            # length 1 right below.
            common_attn_metadata.max_seq_len = min(
                common_attn_metadata.max_seq_len + 1, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping. With M-RoPE all rows of the positions
            # are the same, so the first row is used.
            slot_positions = (
                clamped_positions[0] if self.uses_mrope else clamped_positions
            )
            block_numbers = slot_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            ).view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size + slot_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the cudagraph padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Returns:
            An int32 tensor with one next token id per entry of
            ``sampled_token_ids``, on the same device as ``self.input_ids``.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        # Build the tensor directly instead of rebinding the `list[int]` name
        # to a Tensor.
        return torch.tensor(
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.

        Returns:
            ``(next_token_ids, valid_sampled_tokens_count)``, both with one
            entry per row of ``sampled_token_ids``.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        # One bulk host conversion instead of a `.item()` call per request.
        seq_lens_list = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
                for i in range(num_reqs)
            ]
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = discard_request_indices[
            :num_discarded_requests
        ]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1
        )

        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row. This assumes the valid tokens
        # of a row form a contiguous prefix.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns:
            ``(spec_common_attn_metadata, token_indices,
            token_indices_to_sample)``, where ``token_indices`` covers every
            scheduled token and ``token_indices_to_sample`` holds the index of
            each request's last accepted token.
        """
        # Per-request draft counts, recovered from the cumulative counts.
        num_draft_tokens_gpu = torch.cat(
            [
                spec_decode_metadata.cu_num_draft_tokens[0:1],
                spec_decode_metadata.cu_num_draft_tokens[1:]
                - spec_decode_metadata.cu_num_draft_tokens[:-1],
            ]
        )

        # Requests with drafts sampled `num_draft + 1` candidates; everything
        # beyond the valid count was rejected. Requests without drafts reject 0.
        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        )

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        # Last token of each request, moved back past its rejected tokens.
        token_indices_to_sample = (
            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
        )

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """Draft a token tree level by level using tree attention.

        The root level is sampled directly from ``logits`` (argmax when the
        level has one child, otherwise top-k). Every following level runs the
        draft model once over all tree tokens drafted so far, with attention
        metadata rebuilt for that level by the tree-attention builder, and
        samples ``self.child_drafts_per_level[level + 1]`` children per draft.

        Args:
            batch_size: Number of requests in the batch.
            logits: Draft logits used to sample the root level.
            positions: Positions of the root tokens; also the base for the
                positions of every drafted level.
            hidden_states: Hidden states that accompany the root tokens.
            common_attn_metadata: Attention metadata for the root step. A
                modified copy is made (``dataclasses.replace``) per level.

        Returns:
            One tensor of draft token ids per tree level, each reshaped to
            ``(batch_size, -1)``.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        # (torch.cat skips 1-D empty tensors of size (0,), so these can be
        # concatenated with the 2-D/3-D level outputs along dim=1.)
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=level + 1,
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            # Use a (padded) piecewise CUDA graph only for sizes within the
            # capture limit; larger batches run eagerly.
            if (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            ):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            # Only the last `level_num_drafts` tokens of each request belong
            # to the level that was just drafted.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # Per-request number of rejected tokens. A request with n > 0 draft
        # tokens that produced len(sampled_token_ids[i]) sampled tokens
        # rejected n + 1 - len(...) of them; requests without drafts reject 0.
        num_rejected_tokens_list = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens_list, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # Write the prefix sums straight into the (possibly pinned) tensor
        # through a shared numpy view; element 0 stays 0.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the pre-rejection maximum query length, so
            # it bounds new_num_tokens_per_req from above when rejections are
            # non-negative; confirm consumers only need an upper bound.
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices

    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` has a ``module`` attribute (a multi-GPU wrapper), the
        name of the wrapped ``model.module`` is returned instead.
        """
        if hasattr(model, "module"):  # multi-GPU
            model = model.module
        return model.__class__.__name__

    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and connect it to ``target_model``.

        Records the draft model's own attention and DeepSeek-V3.2 indexer
        layers (those present after the draft model is built but not before),
        builds an indexer metadata builder when the draft has indexer layers,
        probes whether the draft model accepts multimodal embeddings, and
        shares the target's embedding and ``lm_head`` modules with the draft
        model where appropriate.

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: If embeddings are to be shared (pipeline-parallel
                world size 1) but the target language model has neither
                ``embed_tokens`` nor ``embedding``.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Layers registered by building the draft model = all layers now
        # minus the ones the target had already registered.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            if (
                self.get_model_name(target_model)
                == "Qwen2_5_VLForConditionalGeneration"
            ):
                self.model.config.image_token_index = target_model.config.image_token_id
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # torch.allclose broadcasts its inputs: mismatched shapes
                    # would raise or compare the wrong elements, so weights
                    # of different shape are simply treated as distinct.
                    and target_embed_tokens.weight.shape
                    == self.model.model.embed_tokens.weight.shape
                    and torch.allclose(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                        rtol=1e-5,
                        atol=1e-7,
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # torch.equal already returns False for differing shapes.
                and torch.equal(
                    target_language_model.lm_head.weight, self.model.lm_head.weight
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head

    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> None:
        """Run the draft model once on the proposer's preallocated buffers.

        Args:
            num_tokens: Number of tokens to run. It is padded to a CUDA-graph
                size when a CUDA graph is used.
            use_cudagraphs: Allow CUDA graphs for this run. They are used only
                if the proposer also has ``use_cuda_graph`` set and
                ``num_tokens`` is within ``max_cudagraph_capture_size``.
        """
        # Mirror the dispatch used for real drafting: a padded piecewise CUDA
        # graph only for sizes within the capture limit, eager otherwise.
        # (Previously PIECEWISE was requested even for unpadded sizes above
        # the limit.)
        if (
            use_cudagraphs
            and self.use_cuda_graph
            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            cudagraph_runtime_mode = CUDAGraphMode.NONE

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            # Multimodal-capable draft models are driven through embeddings
            # rather than token ids.
            if self.supports_mm_inputs:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_tokens),
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=inputs_embeds,
            )

    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The builder is taken from the first attention group (searching
        ``self.runner.attn_groups`` in order) that contains the first EAGLE
        attention layer, ``self.attn_layer_names[0]``.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: If the eagle layers span zero or several groups.
            KeyError: If an eagle layer is not listed in any group.
        """
        # Map each layer name to the index of the KV cache group holding it
        # (if a name appears in several groups, the last one wins).
        kv_cache_groups: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            kv_cache_groups[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.

# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits`` and return its probs.

    If every request is greedy, the argmax of ``logits`` is returned and
    ``logits`` itself stands in for the probabilities.

    Otherwise ``logits`` is divided IN PLACE by the per-request temperature,
    softmaxed in float32, and sampled with the exponential-race trick
    ``argmax(probs / q)`` where ``q ~ Exp(1)``. Rows whose temperature is
    below ``_SAMPLING_EPS`` are treated as greedy: they use a temperature of
    1.0 and take ``argmax(probs)``.

    Args:
        logits: Draft logits, one row per request. Modified in place on the
            non-all-greedy path.
        sampling_metadata: Supplies ``all_greedy``, ``all_random`` and the
            per-request ``temperature``.

    Returns:
        ``(next_token_ids, probs)``.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        probs = logits
        next_token_ids = logits.argmax(dim=-1)
        return next_token_ids, probs

    assert sampling_metadata.temperature is not None

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold
    temperature = sampling_metadata.temperature
    # Avoid division by zero if there are greedy requests.
    if not sampling_metadata.all_random:
        is_greedy = temperature < _SAMPLING_EPS
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if not sampling_metadata.all_random:
        # Greedy rows ignore the random draw and take the most likely token.
        greedy_token_ids = probs.argmax(dim=-1)
        next_token_ids = torch.where(
            is_greedy,
            greedy_token_ids,
            next_token_ids,
        )
    return next_token_ids, probs