eagle.py 56.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import ast
from dataclasses import replace
5
from importlib.util import find_spec
6

7
import numpy as np
8
9
10
import torch
import torch.nn as nn

11
from vllm.attention.backends.registry import AttentionBackendEnum
12
from vllm.config import (
13
    CompilationMode,
14
15
16
17
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
)
18
from vllm.distributed.parallel_state import get_pp_group
19
from vllm.forward_context import set_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.model_loader import get_model
23
from vllm.model_executor.models import supports_multimodal
24
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
25
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.platforms import current_platform
28
from vllm.triton_utils import triton
29
from vllm.utils.platform_utils import is_pin_memory_available
30
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
31
32
33
34
from vllm.v1.attention.backends.tree_attn import (
    TreeAttentionMetadata,
    TreeAttentionMetadataBuilder,
)
35
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
36
37
38
39
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
40
from vllm.v1.kv_cache_interface import KVCacheConfig
41
from vllm.v1.sample.metadata import SamplingMetadata
42
from vllm.v1.sample.sampler import _SAMPLING_EPS
43
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
44
45
46
47
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
48
from vllm.v1.utils import CpuGpuBuffer
Rémi Delacourt's avatar
Rémi Delacourt committed
49
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
50
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
51

52
53
logger = init_logger(__name__)

54
55
PADDING_SLOT_ID = -1

56
57
58
59
60
61

class EagleProposer:
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        """Cache config-derived settings and allocate persistent input buffers.

        Args:
            vllm_config: Engine configuration. Its ``speculative_config`` must
                not be None.
            device: Device on which the persistent tensors are allocated.
            runner: Optional owning runner. ``propose`` asserts it is set and
                ``propose_tree`` reads ``runner.attn_groups``.
        """
        self.vllm_config = vllm_config
        self.speculative_config = vllm_config.speculative_config
        assert self.speculative_config is not None
        self.draft_model_config = self.speculative_config.draft_model_config
        self.method = self.speculative_config.method

        self.runner = runner
        self.device = device
        self.dtype = vllm_config.model_config.dtype
        self.max_model_len = vllm_config.model_config.max_model_len
        self.block_size = vllm_config.cache_config.block_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        # We need +1 here because the aranges are used to set query_start_loc
        # (GPU, `self.arange`) and query_start_loc_cpu (`self.token_arange_np`),
        # which have one more element than batch_size. Both aranges share this
        # bound so that a `[: batch_size + 1]` slice is never truncated, even
        # when max_num_seqs == max_num_batched_tokens.
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.token_arange_np = np.arange(max_num_slots_for_arange)

        # We need to get the hidden size from the draft model config because
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            vllm_config.model_config
        )

        # Both builders start unset; `propose` falls back to
        # `_get_attention_metadata_builder()` when `attn_metadata_builder`
        # is still None.
        self.attn_metadata_builder: AttentionMetadataBuilder | None = None
        self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

        self.eagle3_use_aux_hidden_state: bool = (
            self._get_eagle3_use_aux_hidden_state_from_config()
        )

        # CUDA graphs are only enabled for VLLM_COMPILE mode with a cudagraph
        # mode that includes PIECEWISE, and only when eager is not enforced.
        self.use_cuda_graph = False

        self.compilation_config = self.vllm_config.compilation_config
        if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
            cudagraph_mode = self.compilation_config.cudagraph_mode
            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
                CUDAGraphMode.PIECEWISE
            ):
                logger.warning(
                    "Currently the eagle proposer only supports cudagraph_mode "
                    "PIECEWISE, if you want the drafter to use cuda graphs, "
                    "please set compilation_config.cudagraph_mode to PIECEWISE "
                    "or FULL_AND_PIECEWISE"
                )
            self.use_cuda_graph = (
                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
                and not self.speculative_config.enforce_eager
            )

        # persistent buffers for cuda graph
        self.input_ids = torch.zeros(
            self.max_num_tokens, dtype=torch.int32, device=device
        )
        self.uses_mrope = self.vllm_config.model_config.uses_mrope
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros(
                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
            )
        else:
            # RoPE need (max_num_tokens,)
            self.positions = torch.zeros(
                self.max_num_tokens, dtype=torch.int64, device=device
            )
        self.hidden_states = torch.zeros(
            (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
        )

        # GPU arange used for query_start_loc; sized with the shared bound
        # computed above.
        self.arange = torch.arange(
            max_num_slots_for_arange, device=device, dtype=torch.int32
        )

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.inputs_embeds_size),
            dtype=self.dtype,
            device=device,
        )

        # Staging buffer for the per-request "backup" next token ids filled in
        # `prepare_next_token_ids_padded`.
        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True,
        )

        # Determine allowed attention backends once during initialization.
        # None means no restriction is enforced in `propose`.
        self.allowed_attn_types: tuple | None = None
        if current_platform.is_rocm():
            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
            # ROCM_AITER_FA is an optional backend
            if find_spec(
                AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
            ):
                from vllm.v1.attention.backends.rocm_aiter_fa import (
                    AiterFlashAttentionMetadata,
                )

                rocm_types.append(AiterFlashAttentionMetadata)
            self.allowed_attn_types = tuple(rocm_types)

        # Parse the speculative token tree.
        spec_token_tree = self.speculative_config.speculative_token_tree
        self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree)
        # The depth is taken from the last node, i.e. the last entry is
        # treated as a deepest node of the tree.
        tree_depth = len(self.tree_choices[-1])
        # Precompute per-level properties of the tree.
        num_drafts_per_level = [0] * tree_depth
        for node in self.tree_choices:
            num_drafts_per_level[len(node) - 1] += 1
        # cu_drafts_per_level[i]: number of nodes in levels 0..i (cumulative).
        # child_drafts_per_level[i]: nodes at level i per node at level i - 1
        # (for level 0, the number of root-level drafts).
        self.cu_drafts_per_level = [num_drafts_per_level[0]]
        self.child_drafts_per_level = [num_drafts_per_level[0]]
        for level in range(1, tree_depth):
            self.cu_drafts_per_level.append(
                self.cu_drafts_per_level[-1] + num_drafts_per_level[level]
            )
            self.child_drafts_per_level.append(
                num_drafts_per_level[level] // num_drafts_per_level[level - 1]
            )
        # Precompute draft position offsets in flattened tree.
        self.tree_draft_pos_offsets = torch.arange(
            1,
            len(self.tree_choices) + 1,
            device=device,
            dtype=torch.int32,
        ).repeat(max_batch_size, 1)

208
209
210
211
212
213
214
215
216
217
218
    def _get_positions(self, num_tokens: int):
        if self.uses_mrope:
            return self.mrope_positions[:, :num_tokens]
        return self.positions[:num_tokens]

    def _set_positions(self, num_tokens: int, positions: torch.Tensor):
        if self.uses_mrope:
            self.mrope_positions[:, :num_tokens] = positions
        else:
            self.positions[:num_tokens] = positions

219
220
221
222
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: torch.Tensor | None,
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Run the draft model and return greedy (argmax) draft token ids.

        The first forward pass consumes all ``num_tokens`` target tokens; each
        further draft step runs one token per request. Persistent buffers
        (``input_ids``, positions, ``hidden_states``, ``inputs_embeds``) are
        overwritten, and ``common_attn_metadata`` is mutated in place when more
        than one token is drafted without tree attention.

        Args:
            target_token_ids: Token ids of the target step.
            target_positions: Positions matching ``target_token_ids``.
            target_hidden_states: Target hidden states; for method "eagle3"
                they are first passed through ``model.combine_hidden_states``.
            next_token_ids: One next token id per request.
            last_token_indices: Index of each request's last token; when None
                it is derived as ``query_start_loc[1:] - 1``.
            common_attn_metadata: Metadata used to build drafting attention
                metadata.
            sampling_metadata: Not read in this method.
            mm_embed_inputs: Optional ``(mm_embeds, is_mm_embed)`` pair passed
                to ``model.embed_input_ids`` when multimodal inputs are
                supported.

        Returns:
            ``[batch_size, 1]`` when ``num_speculative_tokens == 1``,
            ``[batch_size, num_tree_tokens]`` for tree attention, otherwise
            ``[batch_size, num_speculative_tokens]``.

        Raises:
            ValueError: If ``allowed_attn_types`` is set and the built
                attention metadata is not an instance of one of those types
                (only checked when more than one token is drafted).
        """
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states
            )
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[: num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # Use the cached builder if one was set, otherwise look one up.
        if self.attn_metadata_builder is None:
            attn_metadata_builder = self._get_attention_metadata_builder()
        else:
            attn_metadata_builder = self.attn_metadata_builder

        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                )
            )
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata

        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            num_tokens_padded=num_tokens,
        )

        # Use a piecewise CUDA graph only when enabled and the (DP-padded)
        # token count fits within the captured sizes; otherwise run eagerly.
        cudagraph_runtime_mode = CUDAGraphMode.NONE
        if (
            self.use_cuda_graph
            and num_tokens_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens_dp_padded
        # Record this rank's final (possibly cudagraph-padded) token count.
        if num_tokens_across_dp is not None:
            num_tokens_across_dp[self.dp_rank] = num_input_tokens

        # copy inputs to buffer for cudagraph
        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

        if self.supports_mm_inputs:
            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)

            self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
                self.input_ids[:num_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            # The model is fed embeddings instead of token ids in this case.
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        with set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
            )
            # For "mtp" the model output is used as a single tensor; for other
            # methods it is unpacked as a (last_hidden_states, hidden_states)
            # pair.
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = last_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        if self.uses_mrope:
            positions = target_positions[:, last_token_indices]
        else:
            positions = target_positions[last_token_indices]
        # NOTE(review): this branch tests the specific *_mtp method names while
        # the forward-pass branches above/below test "mtp" -- confirm which
        # values `self.method` can actually take here.
        if self.method in (
            "deepseek_mtp",
            "ernie_mtp",
            "longcat_flash_mtp",
            "pangu_ultra_moe_mtp",
        ):
            hidden_states = self.hidden_states[last_token_indices]
        else:
            hidden_states = hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        # `allowed_attn_types` is only populated on ROCm in `__init__`; when it
        # is None no backend restriction applies.
        if self.allowed_attn_types is not None and not isinstance(
            attn_metadata, self.allowed_attn_types
        ):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}"
            )

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
            num_tokens_unpadded=batch_size,
            num_tokens_padded=batch_size,
        )

        # Same CUDA graph decision as above, now for one token per request.
        if (
            self.use_cuda_graph
            and batch_size_dp_padded
            <= self.compilation_config.max_cudagraph_capture_size
        ):
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded)
            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size_dp_padded
            cudagraph_runtime_mode = CUDAGraphMode.NONE
        if batch_size_across_dp is not None:
            batch_size_across_dp[self.dp_rank] = input_batch_size

        # From here on every request contributes exactly one query token, so
        # query_start_loc becomes [0, 1, ..., batch_size].
        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
        # NOTE(review): `token_arange_np` is created in `__init__`; confirm it
        # always has at least batch_size + 1 entries so this slice is complete.
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[: batch_size + 1]
        ).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            if self.uses_mrope:
                positions += 1
                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length.
                # Since it is complex to remove such requests from the batch,
                # we keep them in the batch but adjust the position ids
                # and slot mappings to avoid the
                # out-of-range access during the model execution.
                # The draft tokens generated with this adjustment
                # should be ignored.
                exceeds_max_model_len = positions[0] >= self.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(
                    exceeds_max_model_len.unsqueeze(0),
                    torch.zeros_like(positions),
                    positions,
                )
            else:
                positions += 1
                exceeds_max_model_len = positions >= self.max_model_len
                clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
            # For data integrity when async scheduling, we shouldn't use in place
            # operations in case they are modified in next step's `prepare_input`
            # of main model.
            # Increment the sequence lengths.
            # NOTE(review): `seq_lens` is still updated in place here, unlike
            # `seq_lens_cpu` below -- confirm this is intended.
            common_attn_metadata.seq_lens += 1
            # This is an out-of-place operation to avoid modifying the original tensor.
            common_attn_metadata.seq_lens_cpu = common_attn_metadata.seq_lens_cpu + 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.

            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            common_attn_metadata.num_computed_tokens_cpu = (
                common_attn_metadata.seq_lens_cpu - 1
            )

            # Compute the slot mapping.
            if self.uses_mrope:
                # all dimensions of positions are the same
                block_numbers = clamped_positions[0] // self.block_size
            else:
                block_numbers = clamped_positions // self.block_size
            # Look up, per request, the block id at that block number in the
            # request's row of the block table.
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1)
            )
            block_ids = block_ids.view(-1)
            # slot = block_id * block_size + offset within the block.
            if self.uses_mrope:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions[0] % self.block_size
                )
            else:
                common_attn_metadata.slot_mapping = (
                    block_ids * self.block_size + clamped_positions % self.block_size
                )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID
            )

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
            )
            # NOTE(review): only `attn_layer_names` entries are refreshed;
            # indexer layers keep the metadata built for draft_index=0 --
            # confirm that is intended.
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
            if self.supports_mm_inputs:
                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)

                input_ids = None
                inputs_embeds = self.inputs_embeds[:input_batch_size]
            else:
                input_ids = self.input_ids[:input_batch_size]
                inputs_embeds = None

            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=input_batch_size,
                num_tokens_across_dp=batch_size_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self._get_positions(input_batch_size),
                    hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
                    hidden_states = ret_hidden_states
                else:
                    last_hidden_states, hidden_states = ret_hidden_states
            # Drop the rows beyond batch_size (padding added for DP/cudagraph).
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
526

527
    def prepare_next_token_ids_cpu(
528
        self,
529
        sampled_token_ids: list[list[int]],
530
531
532
533
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
534
535
536
537
538
539
540
541
542
543
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
544
            if token_ids:
545
546
547
548
549
550
551
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
552
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id]
553
554
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
555
        next_token_ids = torch.tensor(
556
557
            next_token_ids, dtype=torch.int32, device=self.input_ids.device
        )
558
        return next_token_ids
559

560
561
562
563
564
565
    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare the next token ids and valid-token counts for speculative
        decoding.

        For "discarded" requests the next token is not sampled; it comes from
        `request.get_token_id()` instead (the "backup" token id). Rejected
        tokens are counted via `sampled_token_ids`. Both results are filled in
        by `eagle_prepare_next_token_padded_kernel`.

        Returns:
            `(next_token_ids, valid_sampled_tokens_count)`, two int32 tensors
            of shape `(batch_size,)` on the device of `sampled_token_ids`.
        """
        # Precompute the backup token of every request on the CPU, for when
        # there is no valid sampled next token.
        num_reqs = gpu_input_batch.num_reqs
        backup_ids = []
        for row in range(num_reqs):
            request = requests[gpu_input_batch.req_ids[row]]
            backup_ids.append(
                request.get_token_id(common_attn_metadata.seq_lens_cpu[row].item())
            )

        # Stage the backup tokens in the CPU buffer, then copy them to the GPU.
        backup_buffer = self.backup_next_token_ids
        backup_buffer.np[:num_reqs] = np.array(backup_ids, dtype=np.int32)
        backup_buffer.copy_to_gpu(num_reqs)
        backup_tokens_gpu = backup_buffer.gpu

        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        # Output buffers, written by the kernel.
        output_options = {"dtype": torch.int32, "device": device}
        next_token_ids = torch.empty((batch_size,), **output_options)
        valid_sampled_tokens_count = torch.empty((batch_size,), **output_options)

        # Block size is the next power of 2 at or above the token count.
        tokens_block = triton.next_power_of_2(num_tokens)
        # Kernel grid: one program per request (row).
        launch_grid = (batch_size,)
        eagle_prepare_next_token_padded_kernel[launch_grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=tokens_block,
        )

        return next_token_ids, valid_sampled_tokens_count

620
621
622
623
624
    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        Prepare the speculator inputs without removing rejected tokens.

        A new `CommonAttentionMetadata` is built for speculative decoding in
        which all tokens stay in the input: rejected tokens act as padding and
        are filtered out later by `token_indices_to_sample`, which is filled in
        by `eagle_prepare_inputs_padded_kernel`.

        Returns:
            `(spec_common_attn_metadata, token_indices_to_sample)`, where the
            latter is an int32 tensor with one entry per request.
        """
        request_count = common_attn_metadata.num_reqs

        # Output buffer for the kernel, on the device of the valid-token counts.
        token_indices_to_sample = torch.empty(
            (request_count,),
            dtype=torch.int32,
            device=valid_sampled_tokens_count.device,
        )

        # Kernel grid: one program per request (row).
        launch_grid = (request_count,)
        eagle_prepare_inputs_padded_kernel[launch_grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            request_count,
        )

        # Token counts come from the CPU copy of query_start_loc.
        cpu_query_starts = common_attn_metadata.query_start_loc_cpu
        per_request_query_len = cpu_query_starts[1:] - cpu_query_starts[:-1]
        total_num_tokens = cpu_query_starts[-1].item()
        longest_query = per_request_query_len.max().item()
        longest_seq = common_attn_metadata.seq_lens_cpu.max().item()

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=cpu_query_starts,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=request_count,
            num_actual_tokens=total_num_tokens,
            max_query_len=longest_query,
            max_seq_len=longest_seq,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices_to_sample
672

673
674
675
676
677
678
679
680
681
682
683
    def propose_tree(
        self,
        batch_size: int,
        # [num_tokens, vocab_size]
        logits: torch.Tensor,
        # [num_tokens]
        positions: torch.Tensor,
        # [num_tokens, hidden_size]
        hidden_states: torch.Tensor,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[torch.Tensor]:
        """Draft tokens level by level for tree-based speculative decoding.

        The root-level draft tokens are picked directly from ``logits``
        (argmax for a single child, top-k otherwise). For every further tree
        level, the drafts accumulated so far are written into the persistent
        ``self.input_ids`` / ``self.positions`` / ``self.hidden_states``
        buffers, the draft model is run with attention metadata produced by
        ``TreeAttentionMetadataBuilder.build_for_drafting``, and the children
        of the newest level are picked from the resulting logits.

        Args:
            batch_size: Number of requests; used to view per-token tensors as
                ``(batch_size, -1)``.
            logits: Logits the root-level draft tokens are picked from.
            positions: Positions the draft positions are offset from.
            hidden_states: Hidden states fed to the draft model for the
                first tree level.
            common_attn_metadata: Metadata that is re-derived (via
                ``dataclasses.replace``) for each tree level.

        Returns:
            One tensor of draft token ids per tree level, each viewed as
            ``(batch_size, -1)``.

        Raises:
            AssertionError: If the metadata builder of the first attention
                group is not a ``TreeAttentionMetadataBuilder``.
        """
        # Only the first attention group's builder is consulted here.
        tree_attn_metadata_builder = self.runner.attn_groups[0][
            0
        ].get_metadata_builder()
        assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)

        # NOTE(review): `cu_drafts_per_level` looks like the cumulative number
        # of drafts up to each level and `child_drafts_per_level` the number of
        # children per node at each level -- confirm against where they are
        # built.
        total_num_drafts = self.cu_drafts_per_level[0]
        level_num_drafts = total_num_drafts
        # Sample a draft token for each child at the tree root level.
        num_children = self.child_drafts_per_level[0]
        if num_children == 1:
            draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
        else:
            draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                batch_size, -1
            )
        draft_token_ids_list = [draft_token_ids]
        draft_hidden_states = hidden_states.view(batch_size, 1, -1)

        # Initialize empty tensors for concatenation with the level outputs.
        tree_input_ids = torch.empty(
            0, device=self.input_ids.device, dtype=self.input_ids.dtype
        )
        tree_positions = torch.empty(
            0, device=self.positions.device, dtype=self.positions.dtype
        )
        tree_hidden_states = torch.empty(
            0, device=self.hidden_states.device, dtype=self.hidden_states.dtype
        )
        # Precompute the draft token positions.
        flattened_draft_positions = (
            positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :]
        )
        tree_depth = len(self.cu_drafts_per_level)
        # The root level was sampled above, so one forward pass is needed for
        # each of the remaining levels.
        for level in range(tree_depth - 1):
            # Get draft positions for RoPE.
            draft_positions = positions + (level + 1)
            exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            draft_positions = torch.where(
                exceeds_max_model_len,
                0,
                draft_positions,
            ).view(batch_size, -1)

            if level_num_drafts > 1:
                # Repeat the positions for each draft at this level.
                draft_positions = draft_positions.repeat_interleave(
                    level_num_drafts, dim=1
                )

            if num_children > 1:
                # Repeat draft hidden states for each child.
                draft_hidden_states = draft_hidden_states.repeat_interleave(
                    num_children, dim=1
                )

            # Concatenate the draft tokens, positions, and hidden states.
            # The accumulators grow along dim=1, so each request's row holds
            # every draft produced so far.
            tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1)
            tree_positions = torch.cat([tree_positions, draft_positions], dim=1)
            tree_hidden_states = torch.cat(
                [tree_hidden_states, draft_hidden_states], dim=1
            )

            # Build new attention metadata for the next level of drafts.
            # This is necessary to support tree attention.
            query_len = total_num_drafts
            common_attn_metadata = replace(
                common_attn_metadata,
                # Every request has the same query length, so the start
                # offsets are multiples of `query_len`.
                query_start_loc=query_len * self.arange[: batch_size + 1],
                seq_lens=common_attn_metadata.seq_lens + level_num_drafts,
                num_actual_tokens=batch_size * query_len,
                max_query_len=query_len,
            )
            attn_metadata = tree_attn_metadata_builder.build_for_drafting(
                common_attn_metadata=common_attn_metadata,
                draft_index=level + 1,
            )

            # Apply new attention metadata to all layers.
            # All layers share the same metadata object, so the in-place
            # updates below are visible through every dict entry.
            per_layer_attn_metadata = {}
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # Consider max model length.
            attn_metadata.max_seq_len = min(
                attn_metadata.max_seq_len, self.max_model_len
            )
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)

            # Compute the slot mapping.
            query_positions = flattened_draft_positions[:, level : level + query_len]
            block_numbers = query_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers)
            slot_mapping = (
                block_ids * self.block_size + query_positions % self.block_size
            )
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            slot_mapping[exceeds_max_model_len] = PADDING_SLOT_ID
            attn_metadata.slot_mapping = slot_mapping.view(-1)

            # Copy inputs to buffer for cudagraph.
            num_tokens = attn_metadata.num_actual_tokens
            input_ids = tree_input_ids.view(-1)
            self.input_ids[:num_tokens] = input_ids
            self.positions[:num_tokens] = tree_positions.view(-1)
            self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1)

            # Pad the token count for CUDA graphs only when the batch fits
            # within the largest captured graph size; otherwise run without
            # CUDA graphs.
            if (
                self.use_cuda_graph
                and num_tokens <= self.compilation_config.max_cudagraph_capture_size
            ):
                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
                cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
            else:
                num_input_tokens = num_tokens
                cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Run the model.
            with set_forward_context(
                per_layer_attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
            ):
                last_hidden_states, hidden_states = self.model(
                    input_ids=self.input_ids[:num_input_tokens],
                    positions=self.positions[:num_input_tokens],
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=None,
                )

            # Get the output hidden states for the draft tokens.
            # Only the trailing `level_num_drafts` entries of each request
            # (the drafts added at this level) are kept.
            draft_hidden_states = hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]
            draft_last_hidden_states = last_hidden_states[:num_tokens].view(
                batch_size, query_len, -1
            )[:, -level_num_drafts:]

            # Get the output logits for the draft tokens.
            logits = self.model.compute_logits(
                draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1)
            )

            # Sample a draft token for each child at the next tree level.
            num_children = self.child_drafts_per_level[level + 1]
            if num_children == 1:
                draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1)
            else:
                draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view(
                    batch_size, -1
                )
            draft_token_ids_list.append(draft_token_ids)

            # Update the # drafts counters for the next tree level.
            level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
            total_num_drafts = self.cu_drafts_per_level[level + 1]
        return draft_token_ids_list

847
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        Args:
            common_attn_metadata: Attention metadata of the step whose
                rejected tokens should be dropped.
            sampled_token_ids: Per request, the token ids that were sampled.
            num_draft_tokens: Per request, the number of draft tokens that
                were proposed (0 means nothing can have been rejected).

        Returns:
            A tuple of the new ``CommonAttentionMetadata`` and a tensor (on
            the device of ``common_attn_metadata.query_start_loc``) holding
            the indices of the tokens that are kept.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  common_attn_metadata.seq_lens{_cpu}:
        #       [s1 - n1, s2 - n2, s3 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        # A request with n drafts can yield up to n + 1 sampled tokens; the
        # shortfall is the number of rejected drafts.
        num_rejected_list = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_list, dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        # The NumPy array shares memory with the tensor, so the cumsum is
        # written straight into `new_query_start_loc_cpu`.
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        # Convert the NumPy scalar to a plain Python int so that
        # `num_actual_tokens` has the same type as in the padded path, which
        # passes the result of `.item()`.
        total_num_tokens = int(new_query_start_loc_np[-1])
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(
            new_query_start_loc_np[:-1], new_num_tokens_per_req_np
        )
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = (
            self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded
        )

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np
        )
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True),
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            # NOTE(review): this is the maximum of the *original* query
            # lengths, i.e. presumably an upper bound on the lengths after
            # rejection -- kept as-is.
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=new_seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices
948

949
    def get_model_name(self, model: nn.Module) -> str:
        """Return the class name of ``model``.

        If ``model`` exposes a ``module`` attribute (the multi-GPU wrapper
        case), the class name of that wrapped object is returned instead.
        """
        unwrapped = getattr(model, "module", model)
        return unwrapped.__class__.__name__

954
    def load_model(self, target_model: nn.Module) -> None:
        """Load the draft model and wire it up to the target model.

        Steps, in order:
          1. Record the attention / indexer layer names registered before
             the draft model is loaded, load the draft model, and treat the
             newly appearing layer names as the draft model's layers.
          2. Create a metadata builder for the draft indexer layers, if any.
          3. Disable ``self.supports_mm_inputs`` when the draft model cannot
             embed inputs with ``multimodal_embeddings=None``.
          4. For multimodal targets, copy the image token index into the
             draft model's config.
          5. Optionally share ``embed_tokens`` and ``lm_head`` with the
             target model.

        Args:
            target_model: The already-loaded target model.

        Raises:
            AttributeError: If embeddings are considered for sharing (pipeline
                parallel world size is 1) but the target language model has
                neither ``embed_tokens`` nor ``embedding``.
        """
        draft_model_config = self.vllm_config.speculative_config.draft_model_config
        # Snapshot the layers that exist *before* loading the draft model so
        # that the draft model's own layers can be identified by set
        # difference afterwards.
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
        )
        # FIXME: support hybrid kv for draft model
        target_indexer_layer_names = set(
            get_layers_from_vllm_config(
                self.vllm_config, DeepseekV32IndexerCache
            ).keys()
        )

        from vllm.compilation.backends import set_model_tag

        with set_model_tag("eagle_head"):
            self.model = get_model(
                vllm_config=self.vllm_config, model_config=draft_model_config
            )

        # Layers present now but not in the snapshots belong to the draft
        # model.
        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )
        draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
        # Indexer layers are tracked separately from regular attention layers.
        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
        self.indexer_layer_names = list(draft_indexer_layer_names)

        if self.indexer_layer_names:
            # The first indexer layer's backend and KV cache spec are used to
            # build a single metadata builder for all draft indexer layers.
            first_layer = self.indexer_layer_names[0]
            self.draft_indexer_metadata_builder = (
                indexer_layers[first_layer]
                .get_attn_backend()
                .get_builder_cls()(
                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                    self.indexer_layer_names,
                    self.vllm_config,
                    self.device,
                )
            )
        else:
            self.draft_indexer_metadata_builder = None

        if self.supports_mm_inputs:
            # Even if the target model is multimodal, we can also use
            # text-only draft models
            # Probe the draft model with a dummy call; any of the caught
            # exceptions is taken to mean "no multimodal support".
            try:
                dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device)
                self.model.embed_input_ids(dummy_input_ids, multimodal_embeddings=None)
            except (NotImplementedError, AttributeError, TypeError):
                logger.warning(
                    "Draft model does not support multimodal inputs, "
                    "falling back to text-only mode"
                )
                self.supports_mm_inputs = False

        if supports_multimodal(target_model):
            # handle multimodality
            # The image token id lives under a different config attribute
            # depending on the target architecture.
            if self.get_model_name(target_model) in [
                "Qwen2_5_VLForConditionalGeneration",
                "Qwen3VLForConditionalGeneration",
            ]:
                self.model.config.image_token_index = target_model.config.image_token_id
            elif self.get_model_name(target_model) == "PixtralForConditionalGeneration":
                self.model.config.image_token_index = (
                    target_model.config.vision_config.image_token_id
                )
            else:
                self.model.config.image_token_index = (
                    target_model.config.image_token_index
                )
            target_language_model = target_model.get_language_model()
        else:
            target_language_model = target_model

        # share embed_tokens with the target model if needed
        # Sharing is only considered when the pipeline-parallel world size is 1.
        if get_pp_group().world_size == 1:
            if hasattr(target_language_model.model, "embed_tokens"):
                target_embed_tokens = target_language_model.model.embed_tokens
            elif hasattr(target_language_model.model, "embedding"):
                target_embed_tokens = target_language_model.model.embedding
            else:
                raise AttributeError(
                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
                )

            share_embeddings = False
            # Draft models exposing `has_own_embed_tokens` are treated as
            # EAGLE models; all others are treated as MTP models.
            if hasattr(self.model, "has_own_embed_tokens"):
                # EAGLE model
                if not self.model.has_own_embed_tokens:
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model without its own embed_tokens in the"
                        " checkpoint. Sharing target model embedding weights with the"
                        " draft model."
                    )
                elif (
                    isinstance(target_embed_tokens.weight, torch.Tensor)
                    and isinstance(self.model.model.embed_tokens.weight, torch.Tensor)
                    # TODO: Offload to CPU for comparison to avoid extra GPU memory
                    # usage in CI testing environments with limited GPU memory
                    and torch.equal(
                        target_embed_tokens.weight.cpu(),
                        self.model.model.embed_tokens.weight.cpu(),
                    )
                ):
                    share_embeddings = True
                    logger.info(
                        "Detected EAGLE model with embed_tokens identical to the target"
                        " model. Sharing target model embedding weights with the draft"
                        " model."
                    )
                else:
                    logger.info(
                        "Detected EAGLE model with distinct embed_tokens weights. "
                        "Keeping separate embedding weights from the target model."
                    )
            else:
                # MTP model
                share_embeddings = True
                logger.info(
                    "Detected MTP model. "
                    "Sharing target model embedding weights with the draft model."
                )

            if share_embeddings:
                # Drop the draft model's own embedding (if any) before
                # pointing it at the target model's module.
                if hasattr(self.model.model, "embed_tokens"):
                    del self.model.model.embed_tokens
                self.model.model.embed_tokens = target_embed_tokens
        else:
            logger.info(
                "The draft model's vocab embedding will be loaded separately"
                " from the target model."
            )

        # share lm_head with the target model if needed
        share_lm_head = False
        # Same EAGLE-vs-MTP distinction as for the embeddings, keyed on
        # `has_own_lm_head`.
        if hasattr(self.model, "has_own_lm_head"):
            # EAGLE model
            if not self.model.has_own_lm_head:
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model without its own lm_head in the checkpoint. "
                    "Sharing target model lm_head weights with the draft model."
                )
            elif (
                hasattr(target_language_model, "lm_head")
                and isinstance(target_language_model.lm_head.weight, torch.Tensor)
                and isinstance(self.model.lm_head.weight, torch.Tensor)
                # TODO: Offload to CPU for comparison to avoid extra GPU memory
                # usage in CI testing environments with limited GPU memory
                and torch.equal(
                    target_language_model.lm_head.weight.cpu(),
                    self.model.lm_head.weight.cpu(),
                )
            ):
                share_lm_head = True
                logger.info(
                    "Detected EAGLE model with lm_head identical to the target model. "
                    "Sharing target model lm_head weights with the draft model."
                )
            else:
                logger.info(
                    "Detected EAGLE model with distinct lm_head weights. "
                    "Keeping separate lm_head weights from the target model."
                )
        else:
            # MTP model
            share_lm_head = True
            logger.info(
                "Detected MTP model. "
                "Sharing target model lm_head weights with the draft model."
            )

        # Sharing is skipped silently when the target has no `lm_head`.
        if share_lm_head and hasattr(target_language_model, "lm_head"):
            if hasattr(self.model, "lm_head"):
                del self.model.lm_head
            self.model.lm_head = target_language_model.lm_head
1134

1135
1136
1137
1138
    @torch.inference_mode()
    def dummy_run(
        self,
        num_tokens: int,
        use_cudagraphs=True,
        is_graph_capturing=False,
    ) -> None:
        """Run the draft model on slices of the persistent input buffers.

        The model is invoked ``self.num_speculative_tokens`` times, or exactly
        once when ``is_graph_capturing`` is true. No attention metadata is
        supplied to the forward context.

        Args:
            num_tokens: Number of tokens to run with, before any padding.
            use_cudagraphs: Allow CUDA graphs for this run; they are only
                used if ``self.use_cuda_graph`` is also set.
            is_graph_capturing: When true, only a single forward pass is run.
        """
        # CUDA graphs need both the caller's consent and the proposer setting.
        graphs_active = use_cudagraphs and self.use_cuda_graph
        runtime_mode = (
            CUDAGraphMode.PIECEWISE if graphs_active else CUDAGraphMode.NONE
        )

        # FIXME: when using tree-based specdec, adjust number of forward-passes
        # according to the depth of the tree.
        num_passes = 1 if is_graph_capturing else self.num_speculative_tokens
        for pass_idx in range(num_passes):
            # Padding is (re)computed for the first two passes only; later
            # passes reuse the values from the second pass.
            if pass_idx <= 1:
                dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
                    num_tokens_unpadded=num_tokens,
                    num_tokens_padded=num_tokens,
                )
                num_input_tokens = dp_padded
                if (
                    graphs_active
                    and dp_padded <= self.compilation_config.max_cudagraph_capture_size
                ):
                    num_input_tokens = self.vllm_config.pad_for_cudagraph(dp_padded)
                if num_tokens_across_dp is not None:
                    # Publish this rank's final (possibly graph-padded) count.
                    num_tokens_across_dp[self.dp_rank] = num_input_tokens

            with set_forward_context(
                None,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=runtime_mode,
            ):
                # Multimodal-capable drafts are fed embeddings; text-only
                # drafts are fed token ids.
                if self.supports_mm_inputs:
                    token_ids, embeds = None, self.inputs_embeds[:num_input_tokens]
                else:
                    token_ids, embeds = self.input_ids[:num_input_tokens], None

                self.model(
                    input_ids=token_ids,
                    positions=self._get_positions(num_input_tokens),
                    hidden_states=self.hidden_states[:num_input_tokens],
                    inputs_embeds=embeds,
                )
1190

1191
    def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
1192
        """Find and return the attention metadata builders for EAGLE layers.
1193

1194
1195
        Returns:
            The metadata builders for EAGLE layers.
1196

1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
        Raises:
            AssertionError: If no metadata builders are found for EAGLE layers.
        """
        builder = None
        chosen_layer = self.attn_layer_names[0]

        for kv_cache_group in self.runner.attn_groups:
            for attn_group in kv_cache_group:
                if chosen_layer in attn_group.layer_names:
                    builder = attn_group.get_metadata_builder()
                    break
            if builder is not None:
                break

        assert builder is not None, (
1212
1213
            "Failed to find attention metadata builder for EAGLE layers."
        )
1214
1215
        return builder

1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
        return use_aux_hidden_state

1232
    def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Validate that all eagle layers belong to the same KVCacheGroup.
        Need this assumption to ensure all eagle layers can use the
        same AttentionMetadata.
        May extend to multiple AttentionMetadata in the future.

        Args:
            kv_cache_config: Config whose ``kv_cache_groups`` each list the
                names of the layers they contain.

        Raises:
            AssertionError: If the layers in ``self.attn_layer_names`` do not
                all map to exactly one KV cache group.
            KeyError: If a layer in ``self.attn_layer_names`` is not part of
                any KV cache group.
        """
        # Map every layer name to the index of the group that contains it.
        group_index_by_layer: dict[str, int] = {}
        for group_index, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                group_index_by_layer[layer_name] = group_index

        eagle_group_indices = {
            group_index_by_layer[layer_name] for layer_name in self.attn_layer_names
        }
        assert len(eagle_group_indices) == 1, (
            "All eagle layers should belong to the same kv cache group"
        )
1254

Rémi Delacourt's avatar
Rémi Delacourt committed
1255
1256
1257
1258
1259
1260
    def _pad_batch_across_dp(
        self,
        num_tokens_unpadded: int,
        num_tokens_padded: int,
    ) -> tuple[int, torch.Tensor]:
        """Coordinate the batch size with the other data-parallel ranks.

        Args:
            num_tokens_unpadded: Token count before padding.
            num_tokens_padded: Token count after local padding.

        Returns:
            A tuple of this rank's token count and the per-rank token counts
            returned by ``coordinate_batch_across_dp``. When that per-rank
            result is ``None``, ``num_tokens_padded`` is returned unchanged
            as the first element and ``None`` as the second.

        Raises:
            AssertionError: If the coordination requests micro-batching.
        """
        # TODO(Flechman): support DBO ubatching
        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            # DP padding is only allowed when CUDA graphs are in use.
            allow_dp_padding=self.use_cuda_graph,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=None,
            num_scheduled_tokens_per_request=None,
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        if num_toks_across_dp is None:
            return num_tokens_padded, num_toks_across_dp
        # Use the count agreed on for this rank.
        own_count = int(num_toks_across_dp[self.dp_rank].item())
        return own_count, num_toks_across_dp

1277

1278
1279
1280
1281
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
# the draft prob tensor.
# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
# FIXME(woosuk): The logic here is duplicated with the main sampling code.
# We should refactor this to reuse the same sampling implementation.
def compute_probs_and_sample_next_token(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample one draft token per row of ``logits``.

    Only the temperature from ``sampling_metadata`` is applied. Note that
    ``logits`` is scaled by the temperature in place unless every request is
    greedy.

    Args:
        logits: Logits to sample from; one row per request.
        sampling_metadata: Provides ``all_greedy``, ``all_random`` and
            ``temperature``.

    Returns:
        ``(next_token_ids, probs)``. When all requests are greedy, ``probs``
        is the unmodified ``logits`` tensor itself; otherwise it is the
        float32 softmax of the temperature-scaled logits.
    """
    if sampling_metadata.all_greedy:
        # For greedy requests, draft_probs is not used in rejection sampling.
        # Therefore, we can just return the logits.
        return logits.argmax(dim=-1), logits

    assert sampling_metadata.temperature is not None
    temperature = sampling_metadata.temperature

    # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0)
    # consistent with sampler.py's _SAMPLING_EPS threshold
    has_greedy_requests = not sampling_metadata.all_random
    if has_greedy_requests:
        is_greedy = temperature < _SAMPLING_EPS
        # Avoid division by zero for the greedy requests.
        temperature = torch.where(is_greedy, 1.0, temperature)

    logits.div_(temperature.view(-1, 1))
    probs = logits.softmax(dim=-1, dtype=torch.float32)

    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
    # generating the draft tokens. We only use the temperature. While this
    # could degrade the acceptance rate, it does not affect the distribution
    # of the generated tokens after rejection sampling.

    # TODO(woosuk): Consider seeds.
    noise = torch.empty_like(probs).exponential_()
    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
    # will be used later for rejection sampling.
    next_token_ids = probs.div(noise).argmax(dim=-1).view(-1)

    if has_greedy_requests:
        # Greedy requests take the most likely token instead of the sample.
        next_token_ids = torch.where(
            is_greedy, probs.argmax(dim=-1), next_token_ids
        )
    return next_token_ids, probs