eagle.py 57.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.config import (
13
    CompilationMode,
14
15
16
17
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
18
from vllm.distributed.parallel_state import get_pp_group
19
from vllm.forward_context import set_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.model_loader import get_model
23
from vllm.model_executor.models import supports_multimodal
24
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
31
32
33
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
34
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
35
36
37
38
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
39
from vllm.v1.kv_cache_interface import KVCacheConfig
40
from vllm.v1.sample.metadata import SamplingMetadata
41
from vllm.v1.sample.sampler import _SAMPLING_EPS
42
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
43
44
45
46
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
47
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
48
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
49
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
50

51
52
# Module-level logger obtained from `init_logger` for this module's name.
logger = init_logger(__name__)

# Sentinel slot id. `EagleProposer.propose` writes it into slot-mapping entries
# of requests whose draft position reached `max_model_len`, so that (per the
# comment there) the KV cache is not updated for those padding tokens.
PADDING_SLOT_ID = -1

55
56
57
58
59
60

class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Set up configuration, persistent buffers and tree metadata.

        Args:
            vllm_config: Global vLLM configuration. Its `speculative_config`
                must not be None (asserted below).
            device: Device on which all persistent tensors are allocated.
            runner: Stored as `self.runner`; `propose` asserts it is not None.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        self.token_arange_np = np.arange(self.max_num_tokens)
        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Builders and layer names start empty here; they are not populated
        # anywhere in this constructor.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        # CUDA graphs stay disabled unless the compilation config below
        # enables PIECEWISE mode and eager execution is not enforced.
        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # We need +1 here because the arange is used to set query_start_loc,
        # which has one more element than batch_size.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        # CPU/GPU staging buffer, one int32 entry per request slot; used by
        # `prepare_next_token_ids_padded` for the "backup" next-token ids.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        # `None` means no restriction is enforced in `propose`.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata

            rocm_types = [
                TritonAttentionMetadata,
                RocmAttentionMetadata,
            ]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)

            # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
            from vllm.v1.attention.backends.mla.common import MLACommonMetadata

            rocm_types.append(MLACommonMetadata)

            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is read from the last entry, so the last node is taken to
        # be (one of) the deepest in the tree.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # Running total of drafts up to each level, and the integer ratio of
        # drafts at a level to drafts at the level above it.
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        # Shape: (max_batch_size, len(tree_choices)); every row is 1..len.
        self.tree_draft_pos_offsets = torch.arange(
            1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
        ).repeat(max_batch_size, 1)

214
215
216
217
218
219
220
221
222
223
224
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

225
226
227
228
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
        num_rejected_tokens_gpu: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return greedy draft token ids.

        The first forward pass consumes the (shifted) target tokens; the
        remaining `num_speculative_tokens - 1` passes each consume the token
        drafted by the previous pass, one token per request.

        Args:
            target_token_ids: Token ids seen by the target model.
            target_positions: Positions matching `target_token_ids`.
            target_hidden_states: Target-model hidden states per token.
            next_token_ids: Next token per request, written at
                `last_token_indices` of the shifted input ids.
            last_token_indices: Index of each request's last token; when
                None it is derived as `query_start_loc[1:] - 1`.
            common_attn_metadata: Attention metadata. NOTE: it is mutated in
                place for the later draft steps (`num_actual_tokens`,
                `max_query_len`, `query_start_loc`, `query_start_loc_cpu`,
                `seq_lens`, `slot_mapping` and the CPU-side shadows).
            sampling_metadata: Not referenced in this method body.
            mm_embed_inputs: Optional `(mm_embeds, is_mm_embed)` pair passed
                to `embed_input_ids` when multimodal inputs are supported.
            num_rejected_tokens_gpu: Optional per-request counts subtracted
                from `seq_lens` before the later draft steps.

        Returns:
            Draft token ids: `[batch_size, 1]` when only one speculative
            token is requested, `[batch_size, num_tree_tokens]` on the tree
            attention path, otherwise `[batch_size, num_speculative_tokens]`.

        Raises:
            ValueError: If `allowed_attn_types` is set and the built attention
                metadata is not an instance of one of those types.
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
        )

        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens_dp_padded
        # Record this rank's final (possibly cudagraph-padded) token count.
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # For "mtp" the model output is used as a single tensor; for
            # every other method it is unpacked as a pair.
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
        )

        if (
            self.use_cuda_graph
            and batch_size_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size_dp_padded
            cudagraph_runtime_mode = CUDAGraphMode.NONE
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # From here on every request contributes exactly one query token, so
        # the metadata is rewritten in place for a batch of `batch_size` tokens.
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()

        # In padded drafter batch, we need to adjust the sequence lengths
        # to remove the "padding" (i.e. rejected tokens).
        # Only apply this adjustment when we have rejected tokens
        # (i.e., not the first proposal).
        if self.num_speculative_tokens > 1 and num_rejected_tokens_gpu is not None:
            common_attn_metadata.seq_lens -= num_rejected_tokens_gpu
            # Invalidate the CPU-side shadows to avoid H<>D sync.
            common_attn_metadata._seq_lens_cpu = None
            common_attn_metadata._num_computed_tokens_cpu = None

        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Also update the CPU-side shadow; NOTE: this is hacky and should be
            # removed in when common_attn_metadata.seq_lens_cpu is deprecated.
            if common_attn_metadata._seq_lens_cpu is not None:
                common_attn_metadata._seq_lens_cpu += 1
            if common_attn_metadata._num_computed_tokens_cpu is not None:
                common_attn_metadata._num_computed_tokens_cpu += 1

            # Compute the slot mapping.
            block_size = attn_metadata_builder.kv_cache_spec.block_size
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // block_size
            else:
                block_numbers = clamped_positions // block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions[0] % block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * block_size + clamped_positions % block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            # Only the attention layers are refreshed here; entries for the
            # indexer layers keep the metadata built before the first pass.
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the padded rows before sampling and before feeding the
            # hidden states into the next draft step.
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
543

544
    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.

        Args:
            sampled_token_ids: Per-request lists of sampled token ids; the
                last element of a non-empty list is taken as the next token.
            requests: Request state keyed by request id.
            gpu_input_batch: Provides `req_ids`, indexed in the same order as
                `sampled_token_ids`.
            num_scheduled_tokens: Number of scheduled tokens per request id.

        Returns:
            An int32 tensor with one next-token id per request, placed on the
            same device as `self.input_ids`.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_id_list: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_id_list.append(next_token_id)
        # Build the tensor under its own name so the list-typed local is not
        # rebound to a different type.
        return torch.tensor(
            next_token_id_list, dtype=torch.int32, device=self.input_ids.device
        )
576

577
578
579
580
581
582
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead. This is denoted
        the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.

        Args:
            common_attn_metadata: Supplies `seq_lens_cpu`, used as the index
                passed to `get_token_id` for each request.
            sampled_token_ids: 2-D tensor unpacked as
                `(batch_size, num_tokens)`.
            requests: Request state keyed by request id.
            gpu_input_batch: Supplies `num_reqs`, `req_ids` and `vocab_size`.
            discard_request_mask: Boolean tensor (dtype asserted below).

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, both int32 tensors
            of length `batch_size` on the device of `sampled_token_ids`,
            filled by `eagle_prepare_next_token_padded_kernel`.
        """
        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        req_ids = gpu_input_batch.req_ids
        # Convert the sequence lengths to Python ints in one call instead of
        # one `.item()` call per request.
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
        self.backup_next_token_ids.np[:num_reqs] = np.array(
            [
                requests[req_ids[i]].get_token_id(seq_lens[i])
                for i in range(num_reqs)
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        # Output buffers written by the kernel; `new_empty` inherits int32
        # dtype and device from `next_token_ids`.
        next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
        valid_sampled_tokens_count = next_token_ids.new_empty(batch_size)

        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count

635
636
637
638
639
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.

        Args:
            common_attn_metadata: Metadata of the current batch; its tensors
                are passed through to the returned metadata, not copied.
            spec_decode_metadata: Supplies `cu_num_draft_tokens` for the
                kernel.
            valid_sampled_tokens_count: Per-request tensor; its device is used
                for the output buffers.

        Returns:
            `(spec_common_attn_metadata, token_indices_to_sample,
            num_rejected_tokens_gpu)`. The last two are int32 tensors of
            length `num_reqs` filled by `eagle_prepare_inputs_padded_kernel`.
        """
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )
        num_rejected_tokens_gpu = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        # One kernel program per request.
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
            num_reqs,
        )

        # Query lengths and the total token count come from the CPU copy of
        # `query_start_loc`.
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            _seq_lens_cpu=common_attn_metadata._seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return (
            spec_common_attn_metadata,
            token_indices_to_sample,
            num_rejected_tokens_gpu,
        )
695

696
697
698
699
700
701
702
703
704
705
706
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """
        Propose draft tokens level by level for tree-based speculation.

        The root-level drafts are sampled from `logits`; then, for every
        further tree level, the drafts accumulated so far are run through
        the draft model with freshly built tree-attention metadata and the
        next level's drafts are sampled from the resulting logits.

        Args:
            batch_size: Number of requests; all per-level tensors are viewed
                as `(batch_size, ...)`.
            logits: Logits used to sample the root-level drafts.
            positions: Positions the draft positions are derived from.
            hidden_states: Hidden states fed to the first drafting pass.
            common_attn_metadata: Attention metadata that is updated (via
                `dataclasses.replace`) for each tree level.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            `(batch_size, -1)`.

        Raises:
            AssertionError: If the first attention group's metadata builder
                is not a `TreeAttentionMetadataBuilder`.
        """
        # The builder is taken from the first attention group of the runner.
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        # `total_num_drafts` tracks the entry of `cu_drafts_per_level` for the
        # current level; `level_num_drafts` is the increment contributed by
        # the current level alone (see the counter update at the loop end).
        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        # NOTE(review): these are 1-D empty tensors later concatenated along
        # dim=1 with multi-dimensional tensors; this presumably relies on
        # torch.cat ignoring empty 1-D inputs - confirm.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        # The root level was sampled above, so only `tree_depth - 1` model
        # passes are needed.
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            # Every request queries all drafts accumulated so far, so the
            # query start locations are uniform multiples of `query_len`.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata, draft_index=level + 1
            )

            # Apply new attention metadata to all layers.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            # Each position is split into a block index (looked up in the
            # block table) and an offset inside that block.
            block_size = tree_attn_metadata_builder.kv_cache_spec.block_size
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = block_ids * block_size + query_positions % block_size
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            # Pad the token count for CUDA graphs only when graphs are
            # enabled and the batch fits within the largest captured size.
            if (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            ):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            # Only the trailing `level_num_drafts` entries of each request
            # (the drafts added at this level) are kept; padding beyond
            # `num_tokens` is dropped first.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

868
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare the inputs for speculative decoding.

        The common attention metadata is rebuilt so that it accounts for the
        tokens rejected in the previous step, and the indices of the tokens
        that should be fed to the speculator are returned alongside it.

        With per-request query lengths [q1, q2, q3], sequence lengths
        [s1, s2, s3] and rejected-token counts [n1, n2, n3], the result has:
          query_start_loc{_cpu}:
              [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
          seq_lens{_cpu}: [s1 - n1, s2 - n2, s3 - n3]
          token_indices: [0, 1, ..., q1 - n1 - 1,
                          q1, q1 + 1, ..., q1 + q2 - n2 - 1,
                          q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
        """
        # A request with n > 0 drafts rejected n + 1 - (#sampled tokens) of
        # them; requests without drafts reject nothing.
        rejected_counts: list[int] = []
        for req_idx, num_drafts in enumerate(num_draft_tokens):
            if num_drafts > 0:
                rejected_counts.append(num_drafts + 1 - len(sampled_token_ids[req_idx]))
            else:
                rejected_counts.append(0)
        rejected_cpu = torch.tensor(rejected_counts, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        old_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        kept_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - rejected_cpu

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        old_query_lens = old_start_loc_cpu[1:] - old_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        kept_query_lens_np = (old_query_lens - rejected_cpu).numpy()

        # Cumulative sum written in place behind a leading zero:
        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        kept_start_loc_cpu = torch.zeros(
            old_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        kept_start_loc_np = kept_start_loc_cpu.numpy()
        np.cumsum(kept_query_lens_np, out=kept_start_loc_np[1:])
        total_num_tokens = kept_start_loc_np[-1]

        # Example with kept lengths [2, 4, 3], i.e. start locs [0, 2, 6, 9]:
        # per-token start of the owning request in the compacted layout
        #   [0, 0, 2, 2, 2, 2, 6, 6, 6]
        kept_starts_per_token = np.repeat(kept_start_loc_np[:-1], kept_query_lens_np)
        # Offset of every token inside its own request
        #   [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1, 0, 1, 2, 3, 0, 1, 2]
        offsets_in_req = (
            self.token_arange_np[:total_num_tokens] - kept_starts_per_token
        )
        # Per-token start of the owning request in the original layout
        #   [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        old_starts_per_token = np.repeat(
            old_start_loc_cpu[:-1].numpy(), kept_query_lens_np
        )
        # Original start + in-request offset gives the index into the
        # un-compacted token buffer.
        token_indices_np = offsets_in_req + old_starts_per_token
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        # Note: max_query_len is taken from the pre-rejection query lengths.
        compacted_metadata = CommonAttentionMetadata(
            query_start_loc=kept_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=kept_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=kept_start_loc_cpu,
            _seq_lens_cpu=kept_seq_lens_cpu,
            _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=old_query_lens.max().item(),
            max_seq_len=kept_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return compacted_metadata, token_indices
969

970
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of `model`, unwrapping a `.module` wrapper.

        Multi-GPU wrappers expose the real model through a `module`
        attribute; when present, the name of that inner object is returned.
        """
        unwrapped = getattr(model, "module", model)
        return type(unwrapped).__name__

975
    def load_model(self, target_model: nn.Module) -> None:
        """
        Load the draft model and wire it up to the target model.

        Steps performed:
          1. Load the draft model and record which attention / indexer
             layers it added on top of the target model's layers.
          2. Build a metadata builder for the draft indexer layers, if any.
          3. Probe whether the draft model accepts multimodal inputs and
             fall back to text-only mode if it does not.
          4. Propagate the image token index from a multimodal target.
          5. Share `embed_tokens` and `lm_head` with the target model when
             the draft model has none of its own or they are identical.

        Raises:
            AttributeError: If the target language model exposes neither
                `embed_tokens` nor `embedding` (only checked when the
                pipeline-parallel world size is 1).
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config

        # Snapshot the layer names registered before the draft model exists.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Whatever was registered since the snapshot belongs to the draft
        # model.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if not self.indexer_layer_names:
            self.draft_indexer_metadata_builder = None
        else:
            # The first indexer layer supplies the backend and KV cache spec.
            first_indexer_layer = indexer_layers[self.indexer_layer_names[0]]
            builder_cls = first_indexer_layer.get_attn_backend().get_builder_cls()
            self.draft_indexer_metadata_builder = builder_cls(
                first_indexer_layer.get_kv_cache_spec(self.vllm_config),
                self.indexer_layer_names,
                self.vllm_config,
                self.device,
            )

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if not supports_multimodal(target_model):
            target_language_model = target_model
        else:
            # handle multimodality: the image token id lives under a
            # different config attribute depending on the target model.
            target_name = self.get_model_name(target_model)
            if target_name in (
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
            ):
                image_token_index = target_model.config.image_token_id
            elif target_name == "PixtralForConditionalGeneration":
                image_token_index = target_model.config.vision_config.image_token_id
            else:
                image_token_index = target_model.config.image_token_index
            self.model.config.image_token_index = image_token_index
            target_language_model = target_model.get_language_model()

        # share embed_tokens with the target model if needed
        if get_pp_group().world_size != 1:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )
        else:
            target_inner_model = target_language_model.model
            if hasattr(target_inner_model, "embed_tokens"):
                target_embed_tokens = target_inner_model.embed_tokens
            elif hasattr(target_inner_model, "embedding"):
                target_embed_tokens = target_inner_model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            if not hasattr(self.model, "has_own_embed_tokens"):
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )
            elif not self.model.has_own_embed_tokens:
                # EAGLE model without embeddings in its checkpoint
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model without its own embed_tokens in the"
                    " checkpoint. Sharing target model embedding weights with the"
                    " draft model."
                )
            elif (
                isinstance(target_embed_tokens.weight, torch.Tensor)
                and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                # Both weights are moved to CPU before the comparison.
                and torch.equal(
                    target_embed_tokens.weight.cpu(),
                    self.model.model.embed_tokens.weight.cpu(),
                )
            ):
                # EAGLE model whose embeddings duplicate the target's
                share_embeddings = True
                logger.info(
                    "Detected EAGLE model with embed_tokens identical to the target"
                    " model. Sharing target model embedding weights with the draft"
                    " model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct embed_tokens weights. "
                    "Keeping separate embedding weights from the target model."
                )

            if share_embeddings:
                # Drop the draft model's own module before aliasing the
                # target's one.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens

        # share lm_head with the target model if needed
        share_lm_head = False
        if not hasattr(self.model, "has_own_lm_head"):
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif not self.model.has_own_lm_head:
            # EAGLE model without an lm_head in its checkpoint
            share_lm_head = True
            logger.info(
                "Detected EAGLE model without its own lm_head in the checkpoint. "
                "Sharing target model lm_head weights with the draft model."
            )
        elif (
            hasattr(target_language_model, "lm_head")
            and isinstance(target_language_model.lm_head.weight, torch.Tensor)
            and isinstance(self.model.lm_head.weight, torch.Tensor)
            # Both weights are moved to CPU before the comparison.
            and torch.equal(
                target_language_model.lm_head.weight.cpu(),
                self.model.lm_head.weight.cpu(),
            )
        ):
            # EAGLE model whose lm_head duplicates the target's
            share_lm_head = True
            logger.info(
                "Detected EAGLE model with lm_head identical to the target model. "
                "Sharing target model lm_head weights with the draft model."
            )
        else:
            logger.info(
                "Detected EAGLE model with distinct lm_head weights. "
                "Keeping separate lm_head weights from the target model."
            )

        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
1155

1156
1157
1158
1159
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs: bool = True,
        is_graph_capturing: bool = False,
    ) -> None:
        """
        Run the draft model on its persistent input buffers.

        Args:
            num_tokens: Number of tokens to run before DP / CUDA graph
                padding is applied.
            use_cudagraphs: Allow CUDA graphs for this run; they are used
                only if the proposer also has them enabled.
            is_graph_capturing: When True a single forward pass is run
                instead of `num_speculative_tokens` passes.
        """
        # Determine if CUDA graphs should be used for this run.
        cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph
        if cudagraphs_enabled:
            runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            runtime_mode = CUDAGraphMode.NONE

        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_forward_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for fwd_idx in range(num_forward_passes):
            # The padded sizes are (re)computed for the first two passes
            # only; later passes reuse the values from the second one.
            if fwd_idx < 2:
                num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
                )
                num_input_tokens = num_tokens_dp_padded
                if (
                    cudagraphs_enabled
                    and num_tokens_dp_padded
                    <= self.compilation_config.max_cudagraph_capture_size
                ):
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(
                        num_tokens_dp_padded
                    )
                # Record this rank's final token count in the DP-wide tensor.
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=runtime_mode,
            ):
                # Multimodal-capable drafts are fed embeddings; all others
                # are fed token ids.
                use_embeds = self.supports_mm_inputs
                input_ids = None if use_embeds else self.input_ids[:num_input_tokens]
                inputs_embeds = (
                    self.inputs_embeds[:num_input_tokens] if use_embeds else None
                )

                self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(num_input_tokens),
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=inputs_embeds,
                )
1210

1211
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
1212
        """Find and return the attention metadata builders for EAGLE layers.
1213

1214
1215
        Returns:
            The metadata builders for EAGLE layers.
1216

1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
        Raises:
            AssertionError: If no metadata builders are found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        assert builder is not None, (
1232
1233
            "Failed to find attention metadata builder for EAGLE layers."
        )
1234
1235
        return builder

1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1252
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Check that every eagle layer lives in one and the same KVCacheGroup.

        This assumption lets all eagle layers share a single
        AttentionMetadata. It may be relaxed to multiple AttentionMetadata
        in the future.

        Raises:
            AssertionError: If the eagle layers do not map to exactly one
                KV cache group.
            KeyError: If an eagle layer is not listed in any KV cache group.
        """
        # Map every layer name to the index of the group that contains it.
        group_of_layer: dict[str, int] = {
            layer_name: group_idx
            for group_idx, group in enumerate(kv_cache_config.kv_cache_groups)
            for layer_name in group.layer_names
        }
        eagle_group_ids = {group_of_layer[name] for name in self.attn_layer_names}
        assert len(eagle_group_ids) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )
1274

Rémi Delacourt's avatar
Rémi Delacourt committed
1275
1276
1277
1278
1279
1280
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor | None]:
        """
        Coordinate this rank's batch size with the other data-parallel ranks.

        Args:
            num_tokens_unpadded: Token count of this rank before padding.
            num_tokens_padded: Token count of this rank after local padding.

        Returns:
            A tuple `(num_tokens_dp_padded, num_toks_across_dp)`.
            `num_toks_across_dp` is whatever `coordinate_batch_across_dp`
            reports and may be None; in that case `num_tokens_dp_padded` is
            simply `num_tokens_padded`. Otherwise it is this rank's entry of
            `num_toks_across_dp`.

        Raises:
            AssertionError: If the coordination asks for micro-batching,
                which is not implemented for EAGLE.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp, _ = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            # DP padding is only allowed when CUDA graphs are in use.
            allow_dp_padding=self.use_cuda_graph,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        # No per-rank token counts were reported: keep the local padding.
        if num_toks_across_dp is None:
            return num_tokens_padded, None

        num_tokens_dp_padded = int(num_toks_across_dp[self.dp_rank].item())
        return num_tokens_dp_padded, num_toks_across_dp

1297

1298
1299
1300
1301
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of `logits` and return its probs.

    Only the temperature from `sampling_metadata` is honoured. Note that
    `logits` is scaled by the temperature in place on the non-greedy path.

    Returns:
        `(next_token_ids, probs)`. When every request is greedy, `probs` is
        the unmodified `logits` tensor; otherwise it is the float32 softmax
        of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None

    temperature = sampling_metadata.temperature
    has_greedy_requests = not sampling_metadata.all_random
    greedy_mask = None
    if has_greedy_requests:
        # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
        # consistent with sampler.py's _SAMPLING_EPS threshold.
        greedy_mask = temperature < _SAMPLING_EPS
        # Greedy rows get temperature 1.0 to avoid division by zero.
        temperature = torch.where(greedy_mask, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs)
    noise.exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    sampled_token_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if not has_greedy_requests:
        return sampled_token_ids, probs

    # Greedy rows take the argmax instead of the random sample.
    greedy_token_ids = probs.argmax(dim=-1)
    next_token_ids = torch.where(greedy_mask, greedy_token_ids, sampled_token_ids)
    return next_token_ids, probs