# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.config import (
12
    CompilationMode,
13
14
15
16
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
17
from vllm.distributed.parallel_state import get_pp_group
18
from vllm.forward_context import set_forward_context
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
21
from vllm.model_executor.model_loader import get_model
22
from vllm.model_executor.models import supports_multimodal
23
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
24
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.platforms import current_platform
27
from vllm.utils.platform_utils import is_pin_memory_available
28
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
29
30
31
32
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
33
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
34
35
36
37
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
38
from vllm.v1.kv_cache_interface import KVCacheConfig
39
from vllm.v1.sample.metadata import SamplingMetadata
40
from vllm.v1.sample.sampler import _SAMPLING_EPS
41
42
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
44

45
46
# Logger for this module, obtained from vLLM's `init_logger`.
logger = init_logger(__name__)

47
48
# Slot id that `EagleProposer.propose` writes into `slot_mapping` (via
# `masked_fill_`) for requests whose draft position reaches `max_model_len`,
# so that the KV cache is not updated with those padding tokens.
PADDING_SLOT_ID = -1

49
50
51
52
53
54

class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Read the drafter configuration and allocate persistent buffers.

        Args:
            vllm_config: Engine configuration. Its `speculative_config` must
                be set (asserted below).
            device: Device on which the persistent tensors are allocated.
            runner: Optional model-runner handle, stored as `self.runner`.
                `propose` asserts that it is not None.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # We need +1 here because the aranges are used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        # `propose` slices this array with `[: batch_size + 1]` to build
        # `query_start_loc_cpu`, so it must be sized like `self.arange` below;
        # `max_num_tokens` alone is one element short when
        # batch_size == max_num_tokens.
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Metadata builders and layer names start out empty; this method does
        # not populate them.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            # CUDA graphs are used only with a PIECEWISE-capable mode and when
            # the speculative config does not enforce eager execution.
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # Per-request fallback token ids, filled by
        # `prepare_next_token_ids_padded`.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        from vllm.attention.backends.registry import AttentionBackendEnum

        # None means `propose` performs no metadata-type check.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): this takes the depth from the last entry, i.e. it
        # presumes the last path in the tree is (one of) the deepest --
        # confirm that the config guarantees this ordering.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # Cumulative number of drafts up to and including each level.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        # Number of drafts at each level divided by the number at the level
        # above (the root level keeps its own count).
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

199
200
201
202
203
204
205
206
207
208
209
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

210
211
212
213
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return greedy draft token ids.

        The first forward pass covers all `num_tokens` target tokens (shifted
        by one, with `next_token_ids` written at `last_token_indices`). If
        more than one speculative token is requested, the remaining tokens
        are generated one per step with a query length of 1 per request, or
        through `propose_tree` when the built attention metadata is a
        `TreeAttentionMetadata`.

        Args:
            target_token_ids: Token ids fed to the target model.
            target_positions: Positions of those tokens.
            target_hidden_states: Target-model hidden states for those tokens.
            next_token_ids: One next token id per request.
            last_token_indices: Index of each request's last token; when None
                it is derived as `query_start_loc[1:] - 1`.
            common_attn_metadata: Attention metadata. NOTE: it is mutated in
                place by the multi-step drafting loop (`num_actual_tokens`,
                `max_query_len`, `query_start_loc[_cpu]`, `seq_lens[_cpu]`,
                `num_computed_tokens_cpu`, `slot_mapping`).
            sampling_metadata: Not read by this method.
            mm_embed_inputs: Optional `(mm_embeds, is_mm_embed)` pair passed
                to `embed_input_ids` when multimodal inputs are supported.

        Returns:
            Draft token ids: `[batch_size, 1]` when a single speculative token
            is requested, the concatenated tree levels for tree attention, and
            `[batch_size, num_speculative_tokens]` otherwise.

        Raises:
            ValueError: If `allowed_attn_types` is set and the attention
                metadata type is not among them when drafting more than one
                token.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {
            layer_name: attn_metadata for layer_name in self.attn_layer_names
        }
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            if self.method == "mtp":
                # The "mtp" model returns a single tensor that serves as both
                # outputs.
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        # NOTE(review): the forward pass above branches on method == "mtp",
        # whereas this check lists model-specific MTP method names -- confirm
        # which method strings can actually reach this point.
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                "Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if (
            self.use_cuda_graph
            and batch_size <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size
            cudagraph_runtime_mode = CUDAGraphMode.NONE

        # From here on every request contributes exactly one query token.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1
            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length.
            # Since it is complex to remove such requests from the batch,
            # we keep them in the batch but adjust the position ids
            # and slot mappings to avoid the
            # out-of-range access during the model execution.
            # The draft tokens generated with this adjustment
            # should be ignored.
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            if self.uses_mrope:
                exceeds_max_model_len = positions[0] >= self.max_model_len
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
                # all dimensions of positions are the same
                flat_positions = clamped_positions[0]
            else:
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
                flat_positions = clamped_positions

            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the original tensor.
            common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping: look up the block that holds each
            # position, then add the offset inside that block.
            block_numbers = flat_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            ).view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size + flat_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop any rows added by cudagraph padding.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        return torch.stack(draft_token_ids_list, dim=1)
498

499
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Args:
            sampled_token_ids: Sampled token ids per request; the i-th entry
                corresponds to `gpu_input_batch.req_ids[i]`.
            requests: Request states keyed by request id.
            gpu_input_batch: Input batch providing `req_ids`.
            num_scheduled_tokens: Number of scheduled tokens per request id.

        Returns:
            An int32 tensor with one next token id per request, placed on the
            same device as `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        # Return the tensor directly instead of rebinding the list variable.
        return torch.tensor(
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )
531

532
533
534
535
536
537
538
539
540
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.

        Args:
            common_attn_metadata: Provides `seq_lens_cpu`, used to look up the
                backup token of each request.
            sampled_token_ids: 2-D tensor of sampled token ids, one row per
                request; entries equal to -1 or >= vocab_size are treated as
                invalid.
            requests: Request states keyed by request id.
            gpu_input_batch: Provides `num_reqs`, `req_ids` and `vocab_size`.
            discard_request_indices: Row indices of requests whose sampled
                tokens must be ignored.
            num_discarded_requests: Number of leading entries of
                `discard_request_indices` that are in use.

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, each with one
            entry per row of `sampled_token_ids`.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # One bulk conversion instead of a `.item()` call per request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)]
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = discard_request_indices[
            :num_discarded_requests
        ]

        # Work on a copy so the caller's `sampled_token_ids` is left untouched.
        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1
        )

        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        # NOTE(review): `count - 1` is the rightmost valid index only if the
        # valid tokens form a prefix of each row -- confirm with the sampler.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )

        return next_token_ids, valid_sampled_tokens_count

602
603
604
605
606
607
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Args:
            common_attn_metadata: Attention metadata of the target step.
            spec_decode_metadata: Provides `cu_num_draft_tokens`, the
                cumulative draft-token counts per request.
            valid_sampled_tokens_count: Number of valid sampled tokens per
                request.

        Returns:
            `(spec_common_attn_metadata, token_indices,
            token_indices_to_sample)`: the metadata to use for drafting, the
            indices `0..total_num_tokens-1`, and for each request the index of
            its last token minus its number of rejected tokens.
        """
        # Turn the cumulative counts back into per-request draft counts.
        cu_num_draft_tokens = spec_decode_metadata.cu_num_draft_tokens
        num_draft_tokens_gpu = torch.cat(
            [
                cu_num_draft_tokens[0:1],
                cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
            ]
        )

        # Requests without draft tokens have nothing to reject.
        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        )

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu

        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        # Step back from each request's last token by its rejected count.
        token_indices_to_sample = (
            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
        )

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

659
660
661
662
663
664
665
666
667
668
669
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """Propose draft tokens level by level for tree-structured drafting.

        The root-level drafts are taken from ``logits``. For every further
        level, all drafts accumulated so far are fed through the draft model
        with tree attention metadata, and the next level's drafts are taken
        from the resulting logits (argmax when a node has a single child,
        top-k otherwise).

        Args:
            batch_size: Leading dimension used when reshaping per-request
                tensors (presumably the number of requests - confirm with
                the caller).
            logits: Logits used to pick the root-level drafts.
            positions: Positions the draft positions are offset from.
            hidden_states: Hidden states fed to the draft model for the
                first tree level.
            common_attn_metadata: Metadata that is re-derived (via
                ``dataclasses.replace``) for every tree level.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            ``(batch_size, -1)``.

        Side effects:
            Overwrites the leading entries of ``self.input_ids``,
            ``self.positions`` and ``self.hidden_states`` on every level.
        """
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=level + 1,
            )

            # Apply new attention metadata to all layers (the same object is
            # shared, so the in-place updates below are seen by every layer).
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            if (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            ):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens. Only the
            # trailing `level_num_drafts` entries of each request are kept.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

833
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """Prepare the speculator inputs after rejection sampling.

        Builds a new ``CommonAttentionMetadata`` in which each request's
        query length and sequence length are shrunk by its number of
        rejected draft tokens, and computes the indices (into the original
        flattened token dimension) of the tokens that are kept.

        Args:
            common_attn_metadata: Metadata describing the original batch.
            sampled_token_ids: Per-request lists of sampled token ids.
            num_draft_tokens: Per-request number of draft tokens. A request
                with ``n > 0`` drafts has
                ``n + 1 - len(sampled_token_ids[i])`` rejected tokens; a
                request with no drafts has none.

        Returns:
            ``(spec_common_attn_metadata, token_indices)`` where
            ``token_indices`` lives on the device of
            ``common_attn_metadata.query_start_loc``.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        rejected_counts = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(rejected_counts, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        old_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = old_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The NumPy array shares memory with the tensor, so the cumsum below
        # fills `new_query_start_loc_cpu[1:]` in place.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Plain Python int rather than a NumPy scalar.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # Computed from the original query lengths, i.e. before the
            # rejected tokens are subtracted.
            max_query_len=old_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
934

935
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` has a ``module`` attribute (multi-GPU wrapper), the
        class name of that wrapped object is returned instead. Only one
        level of wrapping is removed.
        """
        unwrapped = model.module if hasattr(model, "module") else model
        return unwrapped.__class__.__name__

940
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and wire it up to the target model.

        Steps, in order:

        1. Record the attention / indexer layers registered before loading,
           load the draft model under the ``"eagle_head"`` model tag, and
           treat every newly registered layer as belonging to the draft.
        2. Build an indexer metadata builder if the draft has indexer layers.
        3. Probe whether the draft model accepts multimodal inputs and fall
           back to text-only mode if it does not.
        4. For multimodal targets, copy the image token index into the draft
           config and use the target's language model from here on.
        5. Share ``embed_tokens`` and ``lm_head`` with the target model when
           the draft has none of its own, its weights match the target's,
           or the draft exposes no ``has_own_*`` flag (treated as MTP).

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: If the pipeline-parallel world size is 1 and the
                target language model has neither ``embed_tokens`` nor
                ``embedding``.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Layers that appeared after loading the draft model belong to it.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            if (
                self.get_model_name(target_model)
                == "Qwen2_5_VLForConditionalGeneration"
            ):
                self.model.config.image_token_index = target_model.config.image_token_id
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # torch.allclose broadcasts its inputs and raises when the
                    # shapes are incompatible (e.g. different vocab sizes);
                    # differently-shaped weights are simply distinct.
                    and target_embed_tokens.weight.shape
                    == self.model.model.embed_tokens.weight.shape
                    and torch.allclose(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                        rtol=1e-5,
                        atol=1e-7,
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                # Drop the draft's own module first, then point at the
                # target's embedding module.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                and torch.equal(
                    target_language_model.lm_head.weight, self.model.lm_head.weight
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
1113

1114
1115
1116
1117
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
    ) -> None:
        """Run one forward pass of the draft model on the persistent buffers.

        The model output is discarded.

        Args:
            num_tokens: Number of leading buffer entries to feed. It is
                padded via ``pad_for_cudagraph`` when CUDA graphs are enabled
                and it does not exceed the maximum capture size.
            use_cudagraphs: Request CUDA graphs for this run; only honoured
                when ``self.use_cuda_graph`` is also set.
        """
        # Determine if CUDA graphs should be used for this run.
        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
        if (
            cudagraphs_enabled
            and num_tokens <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)

        # NOTE(review): the runtime mode stays PIECEWISE even when num_tokens
        # exceeds the capture size and was therefore not padded above, while
        # propose_tree falls back to NONE in that case - confirm intended.
        runtime_mode = (
            CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
        )

        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=runtime_mode,
        ):
            # Feed either embeddings or token ids, never both.
            if self.supports_mm_inputs:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_tokens),
                hidden_states=self.hidden_states[:num_tokens],
                inputs_embeds=inputs_embeds,
            )
1149

1150
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The first entry of ``self.attn_layer_names`` is used as the
        representative layer: the builder of the attention group that
        contains it is returned.

        Returns:
            The metadata builder for the EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for the EAGLE
                layers.
        """
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            # Only the first matching attention group of each KV cache group
            # is consulted.
            matching_group = next(
                (
                    attn_group
                    for attn_group in kv_cache_group
                    if chosen_layer in attn_group.layer_names
                ),
                None,
            )
            if matching_group is None:
                continue
            builder = matching_group.get_metadata_builder()
            if builder is not None:
                return builder

        # Raised explicitly so the check also runs under `python -O`.
        raise AssertionError(
            "Failed to find attention metadata builder for EAGLE layers."
        )

1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """Return whether the eagle3 head uses auxiliary hidden states.

        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use
        auxiliary hidden states and directly use the last layer output just
        like eagle1. They might indicate this by setting
        "use_aux_hidden_state" to False inside the "eagle_config" dict of
        their hf_config.

        Returns:
            False when the method is not "eagle3". Otherwise the value of
            ``eagle_config["use_aux_hidden_state"]``, defaulting to True when
            the config or the key is absent.
        """
        if self.method != "eagle3":
            return False

        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is None:
            # Assume that eagle3 heads use aux hidden states by default
            return True
        return eagle_config.get("use_aux_hidden_state", True)

1191
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: If the layers in ``self.attn_layer_names`` do not
                map to exactly one KV cache group.
            KeyError: If a layer name is not listed in any KV cache group.
        """
        # Map every layer name to the index of the group that lists it.
        layer_to_group: dict[str, int] = {
            layer_name: group_id
            for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            layer_to_group[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )
1213

1214

1215
1216
1217
1218
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits`` and return the draft probs.

    Only the temperature from ``sampling_metadata`` is used.

    Args:
        logits: Logits to sample from. Unless all requests are greedy, this
            tensor is divided by the temperature IN PLACE.
        sampling_metadata: Provides ``all_greedy``, ``all_random`` and
            ``temperature``.

    Returns:
        ``(next_token_ids, probs)``. When all requests are greedy, ``probs``
        is the unmodified ``logits`` tensor; otherwise it is the float32
        softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None

    temperature = sampling_metadata.temperature
    has_greedy_requests = not sampling_metadata.all_random
    is_greedy = None
    if has_greedy_requests:
        # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
        # consistent with sampler.py's _SAMPLING_EPS threshold.
        is_greedy = temperature < _SAMPLING_EPS
        # Avoid division by zero for the greedy requests.
        temperature = torch.where(is_greedy, 1.0, temperature)
    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    q = torch.empty_like(probs)
    q.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
    if has_greedy_requests:
        # Greedy requests take the argmax instead of the random sample.
        next_token_ids = torch.where(
            is_greedy,
            probs.argmax(dim=-1),
            next_token_ids,
        )
    return next_token_ids, probs