"docs/api/nixl-connect/operation-status.md" did not exist on "f238d23a4d70035442013097bd29e87516379c9e"
eagle.py 57.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.config import (
13
    CompilationMode,
14
15
16
17
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
18
from vllm.distributed.parallel_state import get_pp_group
19
from vllm.forward_context import set_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.model_loader import get_model
23
from vllm.model_executor.models import supports_multimodal
24
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
31
32
33
34
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
35
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
36
37
38
39
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
40
from vllm.v1.kv_cache_interface import KVCacheConfig
41
from vllm.v1.sample.metadata import SamplingMetadata
42
from vllm.v1.sample.sampler import _SAMPLING_EPS
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
44
45
46
47
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
48
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
49
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
50
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
51

52
53
logger = init_logger(__name__)

54
55
PADDING_SLOT_ID = -1

56
57
58
59
60
61

class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """
        Cache configuration values and allocate the persistent buffers used
        while drafting.

        Args:
            vllm_config: Source of the speculative, model, cache, parallel,
                scheduler and compilation settings read below.
                `vllm_config.speculative_config` must not be None.
            device: Device on which every persistent tensor is allocated.
            runner: Stored as `self.runner`; `propose` asserts that it is
                not None.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        # Host-side arange; `propose` slices it to build `query_start_loc_cpu`.
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names start empty here; they are not populated
        # in this method.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []
        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        # Stays False unless compilation mode is VLLM_COMPILE, the cudagraph
        # mode includes PIECEWISE, and enforce_eager is off (see below).
        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            # Warn (but do not fail) when cuda graphs are requested in a mode
            # that does not include PIECEWISE.
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        # Only one of `mrope_positions` / `positions` is created; use
        # `_get_positions` / `_set_positions` to access the right one.
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        # Staging buffer for per-request "backup" next-token ids; filled and
        # copied to the GPU in `prepare_next_token_ids_padded`.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        # None means `propose` performs no metadata-type check (the case on
        # every non-ROCm platform).
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.v1.attention.backends.mla.common import MLACommonMetadata

            rocm_types.append(MLACommonMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # NOTE(review): the depth is read from the LAST entry, so this assumes
        # the last choice is at the deepest level and that the tree is
        # non-empty -- confirm the ordering guarantee on the config side.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        # A node's level is len(node) - 1, so each tuple's length is its depth.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level: running total of drafts up to and including
        # each level. child_drafts_per_level: drafts at a level divided
        # (integer division) by drafts at the previous level.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        # Result has one row per possible request: [1, ..., len(tree_choices)].
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

214
215
216
217
218
219
220
221
222
223
224
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

225
226
227
228
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Run the draft model and return greedy (argmax) draft token ids.

        The first forward pass consumes the target model's tokens, positions
        and hidden states. Unless only one speculative token is requested or
        tree attention is used, the method then runs
        `num_speculative_tokens - 1` further single-token passes, feeding
        each pass the previous pass's argmax token.

        Args:
            last_token_indices: Indices used to pick, per request, the row to
                sample from. When None, `query_start_loc[1:] - 1` is used.
            common_attn_metadata: Mutated in place when more than one token
                is drafted without tree attention (`num_actual_tokens`,
                `max_query_len`, `query_start_loc`, `query_start_loc_cpu`,
                `seq_lens`, `slot_mapping` and the CPU-side shadows).
            sampling_metadata: Not read anywhere in this method body.
            mm_embed_inputs: Optional `(mm_embeds, is_mm_embed)` pair passed
                to `embed_input_ids` when multimodal inputs are supported.
            num_rejected_tokens_gpu: When given (and more than one token is
                drafted), subtracted from `seq_lens` before the loop.

        Returns:
            Draft token ids: `[batch_size, 1]` on the single-token early
            exit, the concatenation of `propose_tree`'s outputs along dim 1
            for tree attention, otherwise
            `[batch_size, num_speculative_tokens]`.

        Raises:
            ValueError: If `allowed_attn_types` is set and the built
                attention metadata is not an instance of one of them.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # Prefer the cached builder; otherwise ask the helper for one. The
        # local is not written back to `self` here.
        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            num_tokens_padded=num_tokens,
        )

        # Use PIECEWISE cuda graphs only when enabled and the padded token
        # count fits within the largest captured size.
        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens_dp_padded
        # Record this rank's final (possibly cudagraph-padded) token count.
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            # The model receives embeddings instead of token ids.
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # "mtp" models return a single tensor; every other method is
            # unpacked as a (last_hidden_states, hidden_states) pair.
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        # Indexing with `last_token_indices` (a tensor) yields new tensors, so
        # the in-place `positions += 1` in the loop below does not modify
        # `target_positions`.
        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        # NOTE(review): the branch above checks `self.method == "mtp"` while
        # this one checks model-specific "*_mtp" names -- confirm which
        # values `self.method` can actually take at this point.
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        # `allowed_attn_types` is only set on ROCm (see `__init__`); elsewhere
        # this check is skipped.
        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        # From here on each request contributes one token per step, so the
        # batch size plays the role of the token count.
        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size,
            num_tokens_padded=batch_size,
        )

        if (
            self.use_cuda_graph
            and batch_size_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size_dp_padded
            cudagraph_runtime_mode = CUDAGraphMode.NONE
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # Rewrite the shared metadata for single-token-per-request decoding:
        # query_start_loc becomes [0, 1, ..., batch_size].
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # NOTE(review): the statements below DO update `seq_lens` in place
            # (`+=`, `masked_fill_`), which contradicts the note above --
            # confirm whether `seq_lens` is a private copy at this point.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            # Look up, per request, the block id at column `block_numbers`
            # of its block-table row.
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            # slot = block_id * block_size + offset within the block.
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions[0] % self.block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions % self.block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            # Only the attention layers are refreshed; entries for
            # `indexer_layer_names` keep the metadata built before the loop.
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the cudagraph/DP padding rows before sampling.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
544

545
    def prepare_next_token_ids_cpu(
546
        self,
547
        sampled_token_ids: list[list[int]],
548
549
550
551
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
552
553
554
555
556
557
558
559
560
561
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
562
            if token_ids:
563
564
565
566
567
568
569
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
570
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
571
572
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
573
        next_token_ids = torch.tensor(
574
575
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )
576
        return next_token_ids
577

578
579
580
581
582
583
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare speculative-decoding inputs from padded, on-device samples.

        Computes, per request, the next token id and the number of valid
        sampled tokens. "Discarded" requests do not use a sampled token; their
        next token is the "backup" id obtained from
        `request.get_token_id()`. Rejected tokens are counted via
        `sampled_token_ids`.

        Returns `(next_token_ids, valid_sampled_tokens_count)`, both int32
        tensors of length `sampled_token_ids.shape[0]` on the same device as
        `sampled_token_ids`.
        """
        # Precompute the backup token for every request, in case it has no
        # valid sampled next token.
        num_reqs = gpu_input_batch.num_reqs
        backup_ids = []
        for req_index in range(num_reqs):
            req_state = requests[gpu_input_batch.req_ids[req_index]]
            backup_ids.append(
                req_state.get_token_id(
                    common_attn_metadata.seq_lens_cpu[req_index].item()
                )
            )
        backup_buffer = self.backup_next_token_ids
        backup_buffer.np[:num_reqs] = np.array(backup_ids, dtype=np.int32)
        backup_buffer.copy_to_gpu(num_reqs)
        backup_tokens_gpu = backup_buffer.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        # Both outputs are written by the kernel below.
        out_spec = {"dtype": torch.int32, "device": device}
        next_token_ids = torch.empty((batch_size,), **out_spec)
        valid_sampled_tokens_count = torch.empty((batch_size,), **out_spec)

        # One kernel program per request (row); the token block size is the
        # next power of two of the row width.
        launch_grid = (batch_size,)
        token_block = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[launch_grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=token_block,
        )

        return next_token_ids, valid_sampled_tokens_count

638
639
640
641
642
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        Prepare the drafter's inputs without removing rejected tokens.

        A new `CommonAttentionMetadata` is built for speculative decoding in
        which every token stays in the batch; rejected tokens act as padding
        and are filtered out later through `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Returns `(spec_common_attn_metadata, token_indices_to_sample,
        num_rejected_tokens_gpu)`.
        """
        src = common_attn_metadata
        num_reqs = src.num_reqs
        device = valid_sampled_tokens_count.device

        # Per-request outputs, filled in by the kernel below.
        out_spec = {"dtype": torch.int32, "device": device}
        token_indices_to_sample = torch.empty((num_reqs,), **out_spec)
        num_rejected_tokens_gpu = torch.empty((num_reqs,), **out_spec)

        launch_grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[launch_grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            src.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        cpu_query_start_loc = src.query_start_loc_cpu
        # Adjacent differences of the start offsets give each query length.
        per_req_query_len = cpu_query_start_loc[1:] - cpu_query_start_loc[:-1]

        total_num_tokens = cpu_query_start_loc[-1].item()

        # Keyword order matches the original call so that `src` attributes
        # are read in the same sequence.
        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=src.query_start_loc,
            seq_lens=src.seq_lens,
            query_start_loc_cpu=cpu_query_start_loc,
            _seq_lens_cpu=src._seq_lens_cpu,
            _num_computed_tokens_cpu=src._num_computed_tokens_cpu,
            num_reqs=src.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=per_req_query_len.max().item(),
            max_seq_len=src.seq_lens_cpu.max().item(),
            block_table_tensor=src.block_table_tensor,
            slot_mapping=src.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=src.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
698

699
700
701
702
703
704
705
706
707
708
709
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """
        Generate draft tokens level by level for tree-based speculation.

        The root-level drafts are sampled from `logits`. For every further
        level in `self.cu_drafts_per_level`, all drafts accumulated so far are
        run through the draft model with tree attention metadata, and the next
        level is sampled from the logits of the most recent level.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            (batch_size, -1).
        """
        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        assert isinstance(builder, TreeAttentionMetadataBuilder)

        def _sample_children(level_logits: torch.Tensor, fanout: int) -> torch.Tensor:
            # A single child is a plain argmax; otherwise take the top-k ids.
            if fanout == 1:
                picked = level_logits.argmax(dim=-1)
            else:
                picked = torch.topk(level_logits, fanout, dim=-1).indices
            return picked.view(batch_size, -1)

        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts

        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        draft_token_ids = _sample_children(logits, num_children)
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Empty seeds that each level's tensors are concatenated onto.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )

        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )

        for level in range(len(self.cu_drafts_per_level) - 1):
            next_level = level + 1

            # Draft positions for RoPE. Position ids that would exceed the max
            # model length are zeroed; otherwise RoPE may hit an out-of-range
            # error.
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            draft_positions = torch.where(
                exceeds_max_model_len, 0, positions + next_level
            ).view(batch_size, -1)

            # Repeat the positions for each draft at this level.
            if level_num_drafts > 1:
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            # Repeat draft hidden states for each child.
            if num_children > 1:
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Append this level's tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=next_level,
            )

            # Every draft layer shares the same metadata object.
            per_layer_attn_metadata = {
                layer_name: attn_metadata for layer_name in self.attn_layer_names
            }

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_ids = attn_metadata.block_table.gather(
                dim=1, index=query_positions // self.block_size
            )
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            self.input_ids[:num_tokens] = tree_input_ids.view(-1)
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            fits_cudagraph = (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            )
            if fits_cudagraph:
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, out_hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Keep only the outputs of the drafts added at this level (the
            # trailing `level_num_drafts` entries of each request's query).
            draft_hidden_states = out_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[next_level]
            draft_token_ids = _sample_children(logits, num_children)
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            cumulative_drafts = self.cu_drafts_per_level[next_level]
            level_num_drafts = cumulative_drafts - total_num_drafts
            total_num_drafts = cumulative_drafts

        return draft_token_ids_list

873
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare the inputs for speculative decoding.

        Builds a new CommonAttentionMetadata that accounts for the rejected
        tokens (and newly sampled tokens), and returns it together with the
        indices of the tokens that should be fed to the speculator.
        """
        # Worked example.
        #  Inputs:
        #   common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #   common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #   num_rejected_tokens: [n1, n2, n3]
        #  Intermediate:
        #   num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        #  Outputs:
        #   common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #   common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        #   token_indices: [0, 1, ..., q1 - n1 - 1,
        #                   q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                   q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with no draft tokens has nothing to reject.
        rejected_per_req = torch.tensor(
            [
                drafted + 1 - len(sampled_token_ids[req_idx]) if drafted > 0 else 0
                for req_idx, drafted in enumerate(num_draft_tokens)
            ],
            dtype=torch.int32,
        )

        device = common_attn_metadata.query_start_loc.device
        old_starts_cpu = common_attn_metadata.query_start_loc_cpu
        kept_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - rejected_per_req

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        old_query_lens = old_starts_cpu[1:] - old_starts_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        kept_per_req_np = (old_query_lens - rejected_per_req).numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_starts_cpu = torch.zeros(
            old_starts_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The numpy view shares storage, so cumsum writes into the tensor.
        new_starts_np = new_starts_cpu.numpy()
        np.cumsum(kept_per_req_np, out=new_starts_np[1:])
        total_kept = new_starts_np[-1]

        # Example with kept_per_req_np = [2, 4, 3], new starts [0, 2, 6, 9]:
        #   repeated new starts: [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #                         _r1_  ____r2____  ___r3__
        #   arange - repeated:   [0, 1, 0, 1, 2, 3, 0, 1, 2]
        new_starts_repeated = np.repeat(new_starts_np[:-1], kept_per_req_np)
        token_offsets = self.token_arange_np[:total_kept] - new_starts_repeated

        # Repeat each request's original start the same way:
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        old_starts_repeated = np.repeat(old_starts_cpu[:-1].numpy(), kept_per_req_np)

        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_starts_repeated
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_starts_cpu.to(device, non_blocking=True),
            seq_lens=kept_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_starts_cpu,
            _seq_lens_cpu=kept_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_kept,
            max_query_len=old_query_lens.max().item(),
            max_seq_len=kept_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
974

975
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of `model`, looking through a `.module` wrapper."""
        # A model exposing `.module` is treated as a wrapper (multi-GPU case);
        # the name reported is that of the wrapped model.
        unwrapped = getattr(model, "module", model)
        return unwrapped.__class__.__name__

980
    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and connect it to the target model.

        Steps, in order:
          1. Load the draft model under the "eagle_head" model tag and record
             which attention / indexer layers it added to the config.
          2. Build a metadata builder for the draft indexer layers, if any.
          3. Probe whether the draft model accepts multimodal embedding calls
             and fall back to text-only mode if it does not.
          4. Copy the image token index from a multimodal target model.
          5. Share `embed_tokens` and `lm_head` with the target model when the
             draft model has none of its own, has identical weights, or is an
             MTP model.

        Raises:
            AttributeError: If embeddings may be shared (PP world size 1) but
                the target model has neither `embed_tokens` nor `embedding`.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config

        # Layer names present before the draft model is loaded; anything that
        # appears afterwards is attributed to the draft model.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        all_attn_layers = get_layers_from_vllm_config(
            self.vllm_config, AttentionLayerBase
        )
        draft_attn_layer_names = all_attn_layers.keys() - target_attn_layer_names
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        # The indexer metadata builder is derived from the first indexer layer.
        self.draft_indexer_metadata_builder = None
        if self.indexer_layer_names:
            first_indexer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = first_indexer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                first_indexer.get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )

        if self.supports_mm_inputs:
            # A multimodal target can still be paired with a text-only draft
            # model, so probe the draft model's embedding call once.
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # The attribute holding the image token id depends on the
            # target architecture.
            target_name = self.get_model_name(target_model)
            if target_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
            ):
                image_token_index = target_model.config.image_token_id
            elif target_name == "PixtralForConditionalGeneration":
                image_token_index = target_model.config.vision_config.image_token_id
            else:
                image_token_index = target_model.config.image_token_index
            self.model.config.image_token_index = image_token_index
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
        else:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = True
            if not hasattr(self.model, "has_own_embed_tokens"):
                # MTP model
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )
            elif not self.model.has_own_embed_tokens:
                # EAGLE model whose checkpoint has no embedding
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # The comparison runs on CPU copies to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory.
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                logger.info(
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                share_embeddings = False
                logger.info(
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )

            if share_embeddings:
                # Drop the draft model's own module before aliasing the
                # target's one.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens

        # share lm_head with the target model if needed
        share_lm_head = True
        if not hasattr(self.model, "has_own_lm_head"):
            # MTP model
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif not self.model.has_own_lm_head:
            # EAGLE model whose checkpoint has no lm_head
            logger.info(
                "Detected EAGLE model without its own lm_head in the checkpoint. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif (
            hasattr(target_language_model, "lm_head")
            and isinstance(target_language_model.lm_head.weight, torch.Tensor)
            and isinstance(self.model.lm_head.weight, torch.Tensor)
            # The comparison runs on CPU copies to avoid extra GPU memory
            # usage in CI testing environments with limited GPU memory.
            and torch.equal(
                target_language_model.lm_head.weight.cpu(),
                self.model.lm_head.weight.cpu(),
            )
        ):
            logger.info(
                "Detected EAGLE model with lm_head identical to the target model. "
                "Sharing target model lm_head weights with the draft model."
            )
        else:
            share_lm_head = False
            logger.info(
                "Detected EAGLE model with distinct lm_head weights. "
                "Keeping separate lm_head weights from the target model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
1160

1161
1162
1163
1164
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs=True,
        is_graph_capturing=False,
    ) -> None:
        """
        Run dummy forward passes of the draft model.

        The forward context is set up with no attention metadata.

        Args:
            num_tokens: Number of tokens to run, before DP / CUDA graph padding.
            use_cudagraphs: Allow CUDA graphs for this run; they are only used
                if `self.use_cuda_graph` is also set.
            is_graph_capturing: When true, a single forward pass is run
                instead of `self.num_speculative_tokens` passes.
        """
        # CUDA graphs need both the caller's request and the proposer setting.
        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
        runtime_mode = (
            CUDAGraphMode.PIECEWISE if cudagraphs_enabled else CUDAGraphMode.NONE
        )

        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for fwd_idx in range(num_passes):
            # The padded sizes are computed for the first two passes only;
            # later passes reuse the values from the second pass.
            if fwd_idx <= 1:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    num_tokens_padded=num_tokens,
                )
                capture_limit = self.compilation_config.max_cudagraph_capture_size
                if cudagraphs_enabled and num_tokens_dp_padded <= capture_limit:
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
                        num_tokens_dp_padded
                    )
                else:
                    num_input_tokens = num_tokens_dp_padded
                # Record this rank's final token count in the per-rank tensor.
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=runtime_mode,
            ):
                # Multimodal-capable drafts are fed embeddings; others get
                # token ids.
                input_ids = None
                inputs_embeds = None
                if self.supports_mm_inputs:
                    inputs_embeds = self.inputs_embeds[:num_input_tokens]
                else:
                    input_ids = self.input_ids[:num_input_tokens]

                self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=inputs_embeds,
                )
1216

1217
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
        """Find and return the attention metadata builder for EAGLE layers.

        The lookup uses the first name in `self.attn_layer_names`: the builder
        comes from the first attention group, in each KV cache group of
        `self.runner.attn_groups`, that contains that layer.

        Returns:
            The metadata builder for EAGLE layers.

        Raises:
            AssertionError: If no metadata builder is found for EAGLE layers.
        """
        chosen_layer = self.attn_layer_names[0]
        builder = None

        for kv_cache_group in self.runner.attn_groups:
            owning_group = next(
                (
                    attn_group
                    for attn_group in kv_cache_group
                    if chosen_layer in attn_group.layer_names
                ),
                None,
            )
            if owning_group is None:
                continue
            builder = owning_group.get_metadata_builder()
            if builder is not None:
                break

        assert builder is not None, (
            "Failed to find attention metadata builder for EAGLE layers."
        )
        return builder

1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1258
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every eagle layer lives in one and the same KVCacheGroup.

        This assumption lets all eagle layers use the same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Raises:
            AssertionError: If the eagle layers do not map to exactly one
                KV cache group.
            KeyError: If an eagle layer is not listed in any KV cache group.
        """
        # Map each layer name to the index of its KV cache group; a layer
        # listed in several groups keeps the last index seen.
        group_of_layer: dict[str, int] = {
            layer_name: group_idx
            for group_idx, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in kv_cache_group.layer_names
        }
        eagle_group_ids = {
            group_of_layer[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )
1280

Rémi Delacourt's avatar
Rémi Delacourt committed
1281
1282
1283
1284
1285
1286
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor]:
        """
        Coordinate this rank's token count with the other DP ranks.

        Micro-batching is disabled, and DP padding is allowed only when
        `self.use_cuda_graph` is set.

        Returns:
            A tuple of (token count for this rank, per-rank token counts).
            The second element is whatever `coordinate_batch_across_dp`
            returned and may be None. In that case the first element is
            `num_tokens_padded` unchanged.

        Raises:
            AssertionError: If the coordinator asks for micro-batching.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            allow_dp_padding=self.use_cuda_graph,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        if num_toks_across_dp is None:
            return num_tokens_padded, num_toks_across_dp

        # Use the count recorded for this DP rank.
        rank_tokens = int(num_toks_across_dp[self.dp_rank].item())
        return rank_tokens, num_toks_across_dp

1303

1304
1305
1306
1307
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of `logits` and return the draft probs.

    Only the temperature from `sampling_metadata` is used. `logits` is divided
    by the temperature in place unless all requests are greedy.

    Returns:
        A tuple of (next_token_ids, probs). When all requests are greedy,
        `probs` is the unmodified `logits` tensor; otherwise it is the float32
        softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling,
        # so the logits are returned in their place.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None
    temperature = sampling_metadata.temperature
    has_greedy_requests = not sampling_metadata.all_random

    if has_greedy_requests:
        # Greedy requests are detected by an epsilon comparison
        # (temperature ~ 0.0), consistent with sampler.py's _SAMPLING_EPS
        # threshold. Their temperature becomes 1.0 to avoid dividing by zero.
        is_greedy = temperature < _SAMPLING_EPS
        temperature = torch.where(is_greedy, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs)
    noise.exponential_()
    # NOTE(woosuk): `probs` must not be divided in place, because the
    # draft_probs will be used later for rejection sampling.
    sampled_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if not has_greedy_requests:
        return sampled_ids, probs

    # Greedy rows take the argmax instead of the sampled id.
    next_token_ids = torch.where(is_greedy, probs.argmax(dim=-1), sampled_ids)
    return next_token_ids, probs